Active Dendrites

Avoiding catastrophic forgetting

Photo by Henry Be

The following content is mainly about the article Avoiding Catastrophe: Active Dendrites Enable Multi-Task Learning in Dynamic Environments by A. Iyer et al. (December 2021). It is a pleasant paper that mixes biology, neuroscience and mathematical modeling; I hope you find it interesting.

Catastrophic forgetting

Standard Artificial Neural Networks (ANNs), based on the (inaccurate) point neuron model [Lapicque, 1907] and the backpropagation algorithm, often fail dramatically at multi-task learning. Unlike single-task machine learning, learning multiple distinct tasks introduces new complications. With gradient-based methods such as backpropagation, a notable issue is that error gradients and accumulated knowledge from different tasks can interfere with one another: adjusting the weights to reduce the error on one task may lead to poor or even ruinous performance on another. This problem is commonly known as catastrophic forgetting.

The same is true for continual learning, which concerns the ability to acquire new knowledge over time while retaining relevant information from the past. A typical scenario involves training a network on a set of distinct tasks presented in a strict sequence of training phases. As a basic example, consider two different learning tasks: (1) classifying dog breeds and (2) identifying letters of the Aramaic alphabet.

In essence, learning means starting from an initial weight configuration and moving through weight space to a region where the error on the task being learned is small.

The figure above gives an intuitive picture of what happens. Consider learning the two aforementioned tasks, 1 and 2, in sequence. From the initial weight configuration (yellow dot), learning to classify dogs brings us to a minimum region (a). Then we learn to identify letters, and the weights are modified until they reach a minimum region (b). In doing so, the network completely disregards which weight configuration was appropriate for the first task.
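The same effect can be reproduced with a deliberately tiny toy problem. This is a minimal sketch, not taken from the paper: a single shared weight `w`, two hypothetical tasks whose squared-error losses (`loss_1`, `loss_2`) have their minima at different points of weight space, and plain gradient descent run first on task 1 and then on task 2.

```python
# Hypothetical 1-D toy: one shared weight w and two tasks whose squared-error
# losses have minima at different points of weight space.
def loss_1(w):
    return (w - 2.0) ** 2      # task 1 is solved at w = 2


def grad_1(w):
    return 2.0 * (w - 2.0)


def loss_2(w):
    return (w + 3.0) ** 2      # task 2 is solved at w = -3


def grad_2(w):
    return 2.0 * (w + 3.0)


def train(grad, w, lr=0.1, steps=100):
    """Plain gradient descent on a single task, starting from weight w."""
    for _ in range(steps):
        w -= lr * grad(w)
    return w


w0 = 0.0                   # initial configuration (the yellow dot)
w_a = train(grad_1, w0)    # learn task 1 -> region (a)
w_b = train(grad_2, w_a)   # then learn task 2 -> region (b)

print(f"after task 1: w = {w_a:.3f}, task-1 loss = {loss_1(w_a):.4f}")
print(f"after task 2: w = {w_b:.3f}, task-1 loss = {loss_1(w_b):.4f}, "
      f"task-2 loss = {loss_2(w_b):.4f}")
```

Nothing in the task-2 gradient refers back to task 1, so the weight simply slides to the task-2 minimum and the task-1 loss climbs from roughly 0 to roughly 25.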

Biological neurons and Active Dendrites

The point neuron model postulates that all of a neuron’s synapses have a linear impact on the cell. This simple assumption laid the foundations of Rosenblatt’s original Perceptron [Rosenblatt, 1958] and still forms the basis of current deep learning networks.

This artificial neuron has relatively few synapses and no dendrites. Learning occurs by changing the strength, or “weight”, of the synapses, each represented by a scalar that can be positive or negative. The neuron computes a weighted sum of its inputs, and a non-linear function f then determines its output value. It is now well known that the point neuron assumption is an oversimplified model of biological computation.
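As a minimal sketch of the point neuron just described, the function below integrates all inputs linearly and applies a single non-linearity f. The name `point_neuron`, the default `tanh`, and the ReLU used in the example are illustrative choices, not anything prescribed by the paper.

```python
import math


def point_neuron(x, w, b, f=math.tanh):
    """Point neuron: weighted sum of the inputs plus a bias, passed through f."""
    s = sum(wi * xi for wi, xi in zip(w, x)) + b   # linear integration of all synapses
    return f(s)                                    # single non-linearity at the output


def relu(s):
    return max(0.0, s)


# Example: 0.5*1 + 0.25*2 + (-1)*(-1) + 0.1 = 2.1, which ReLU passes through unchanged
y = point_neuron([1.0, 2.0, -1.0], [0.5, 0.25, -1.0], 0.1, f=relu)
print(y)
```

Every synapse contributes in the same way, through one scalar weight; there is no notion of where on the cell an input arrives.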

Pyramidal neurons (see figure below) are the most common type of neuron in the neocortex. Biological neurons have thousands of synapses arranged along their dendrites. Biological synapses are partly stochastic and therefore have low precision. Learning in a biological neuron mostly involves forming new synapses and removing unused ones.

In real neurons, proximal synapses (those close to the cell body) have a linear impact on the neuron, but most synapses are located on distal dendritic segments (away from the cell body). These distal segments are known as active dendrites and process their synaptic input non-linearly. When the input to an active dendritic segment reaches a threshold, the segment initiates a dendritic spike that travels to the cell body and can depolarize the neuron for an extended period, as long as half a second. During this period the neuron is closer to its firing threshold, so any new input is more likely to make it fire. Hence, unlike proximal segments, active dendrites have a modulatory and long-lasting impact on the neuron’s activity. Active dendritic segments receive input from cells in other layers or in the form of top-down feedback.

Sparse Representations

Neural circuits in the neocortex are highly sparse. Studies reveal that relatively few neurons spike in response to a sensory stimulus. Neural connectivity is also sparse: pyramidal neurons are sparsely connected to each other and receive relatively few signals from neighboring neurons.

This is not the case in neural network modeling, where connections are mostly dense. In models, sparse representations are vectors in which most entries are zero. Studies show that sparse representations are more resistant to noise than dense ones; furthermore, pattern recognition based on them is less affected by slight perturbations of the input.

Active Dendrites Neuron

The authors propose a new neuron model. Mimicking pyramidal neurons, the active dendrites neuron receives two sources of input, analogous to proximal and distal inputs. The feedforward input is processed exactly as in a point neuron. At the same time, multiple dendritic segments process a context vector, and their output modulates the feedforward activation. In other words, the magnitude of the response to a given stimulus is highly context-dependent. The image below shows five dendrites processing the context (the weights involved are represented by small discs) alongside the feedforward input.

Given input x, weights w and bias b, the feedforward signal is, as usual, computed as

\hat{t} = \mathbf{w}^\top \mathbf{x} + b \,.

Note that the weights here are not a 2D matrix but a vector containing only the values associated with this particular neuron (needless to say, we are defining how a single artificial neuron works). Each dendrite j, on the other hand, computes

\mathbf{u}_j^\top \mathbf{c}

where u_j are the weights of the j-th dendrite and c is a context vector (for example, the context vector may encode task ID information). We will not delve too deeply into how this context vector is computed but, in short, the context vector:

1) is computed using prototype representations for different classes;

2) if the system receives task information during training, the prototype vector for a given task is computed as the element-wise mean, feature by feature, over all of that task's training samples;

3) if the system receives no task information during training, a statistical clustering approach is used: if a new batch of samples is similar to earlier training samples, it is assigned to an existing prototype; if not, the batch is assumed to belong to a new task and a new prototype is instantiated.
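The steps above can be sketched as follows. This is only an illustration under stated assumptions: `PrototypeContext`, `mean_vector` and the `new_task_threshold` parameter are hypothetical names, and a plain Euclidean-distance threshold stands in for the statistical clustering test used in the paper.

```python
import math


def mean_vector(samples):
    """Element-wise mean over samples (each sample is a list of features)."""
    n = len(samples)
    return [sum(col) / n for col in zip(*samples)]


class PrototypeContext:
    """Hypothetical helper: stores one prototype per task and returns it as context c."""

    def __init__(self, new_task_threshold):
        self.prototypes = []
        self.threshold = new_task_threshold   # assumed distance threshold (not from the paper)

    def add_task(self, samples):
        # Step 2: task ID known, so the prototype is the mean of that task's samples
        self.prototypes.append(mean_vector(samples))
        return len(self.prototypes) - 1

    def infer(self, batch):
        # Step 3: no task ID, so compare the batch mean with the stored prototypes
        m = mean_vector(batch)
        if self.prototypes:
            dists = [math.dist(m, p) for p in self.prototypes]   # Euclidean, Python 3.8+
            best = min(range(len(dists)), key=dists.__getitem__)
            if dists[best] <= self.threshold:
                return best, self.prototypes[best]
        # Too far from every prototype: treat the batch as a new task
        return self.add_task(batch), self.prototypes[-1]


ctx = PrototypeContext(new_task_threshold=1.0)
ctx.add_task([[0.0, 0.0], [0.2, 0.0]])     # task A, prototype [0.1, 0.0]
task, c = ctx.infer([[0.1, 0.1]])          # close to A: reuses prototype 0
new_task, c_new = ctx.infer([[5.0, 5.0]])  # far from A: new prototype 1
```

The returned prototype is what would be fed to the dendritic segments as the context vector c.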

The figure above illustrates the prototype method. Yellow points represent samples for task A, beige for task B.

Returning to our neuron model, the segment with the strongest response to the context is selected:

\displaystyle d = \max_j \mathbf{u}_j^\top \mathbf{c}\, .

The contextual contribution of the active dendrites modulates the feedforward activation as follows:

\displaystyle y= f(\hat{t},d) = \hat{t} \cdot \sigma(d)\,.

In the expression above, y is the resulting activation and σ is the sigmoid function, which maps any real number into the range (0, 1). Hence weak responses to the context vector significantly reduce the resulting activation: a response d near zero halves it, since σ(0) = 0.5, and a strongly negative d drives it toward zero.
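Putting the three formulas together, a single active dendrites neuron can be sketched as below. It follows the equations for t̂, d and y term by term; the function name `active_dendrite_neuron` and the toy weights are hypothetical, and the sigmoid is written in a numerically stable form so that very negative d does not overflow.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function, mapping R into (0, 1)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def active_dendrite_neuron(x, c, w, b, U):
    """Forward pass of one active dendrites neuron.

    x, w, b: feedforward input, weights and bias; c: context vector;
    U: list of dendritic segment weight vectors u_j.
    Returns (y, j_star), where j_star indexes the winning segment.
    """
    t_hat = dot(w, x) + b                       # feedforward: w^T x + b
    responses = [dot(u_j, c) for u_j in U]      # each segment: u_j^T c
    j_star = max(range(len(U)), key=responses.__getitem__)
    d = responses[j_star]                       # d = max_j u_j^T c
    return t_hat * sigmoid(d), j_star           # y = t_hat * sigma(d)


x, w, b = [1.0, 1.0], [1.0, 1.0], 0.0          # feedforward part: t_hat = 2
U = [[4.0, 0.0], [0.0, 4.0]]                   # two segments, each tuned to one context
y_a, seg_a = active_dendrite_neuron(x, [1.0, 0.0], w, b, U)    # segment 0 wins
y_b, seg_b = active_dendrite_neuron(x, [0.0, 1.0], w, b, U)    # segment 1 wins
y_off, _ = active_dendrite_neuron(x, [-1.0, -1.0], w, b, U)    # no segment matches
```

With the same feedforward input, a context that some segment recognizes yields y ≈ 2 · σ(4) ≈ 1.96, while an unrecognized context gives y ≈ 2 · σ(−4) ≈ 0.04.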

Modeling sparsity

To add sparsity to active dendrites architectures, the authors apply the kWTA (k-Winners-Take-All) function, which mimics biological inhibitory networks and is defined as follows:

{k(y_i) = \begin{cases} y_i & \textsf{if}\; y_i\; \textsf{is one of the top}\, k \, \textsf{activations over all} \, i\\ 0 & \textsf{otherwise}\end{cases}}

where i indexes neurons in the same layer. Sparsity is ensured by selecting the top k activations and setting all others to zero.
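A direct transcription of the kWTA definition might look like this (the name `kwta` is hypothetical). It keeps the k largest activations of a layer at their original values and zeroes every other entry.

```python
def kwta(activations, k):
    """k-Winners-Take-All: keep the k largest activations, set the rest to zero."""
    if k >= len(activations):
        return list(activations)
    # Indices of the k largest activations (ties broken by position, via stable sort)
    winners = sorted(range(len(activations)),
                     key=activations.__getitem__, reverse=True)[:k]
    keep = set(winners)
    return [a if i in keep else 0.0 for i, a in enumerate(activations)]


print(kwta([0.3, -1.2, 2.5, 0.9, 0.1], k=2))
```

Sorting costs O(n log n); for large layers, `heapq.nlargest(k, range(n), key=activations.__getitem__)` selects the same winners more cheaply.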

Active Dendrites Network Architecture

The figure below shows a network of active dendrites neurons. All neurons in each hidden layer are active dendrites neurons. The network is trained with backpropagation.

The neurons selected by the kWTA function are the only ones with nonzero activations (hence nonzero gradients), so they are the only ones updated during the backward pass of backpropagation.

Only a very small, sparse subset of the full network is actually updated for each input: on top of the kWTA selection, for each of those “winner” neurons only the dendritic segment chosen by the max operator is updated (the other segments are not modified).
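To see where the sparsity of the updates comes from, the sketch below runs a toy layer of active dendrites neurons followed by kWTA and computes, by hand, the gradients of a stand-in objective L (the sum of the kWTA outputs) with respect to every parameter. The names and numbers are hypothetical. It uses ∂y/∂w = σ(d)x, ∂y/∂b = σ(d) and ∂y/∂u_j* = t̂ σ(d)(1 − σ(d)) c, and treats the max and kWTA selections as fixed, which is the usual convention for piecewise selections. A real task loss only multiplies these terms by ∂L/∂y, so the zero pattern is the same.

```python
import math


def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def layer_grads(x, c, W, B, U_all, k):
    """Toy active dendrites layer + kWTA, with hand-derived gradients of
    L = sum(kWTA outputs). W[i], B[i]: feedforward weights/bias of neuron i;
    U_all[i][j]: weights of segment j of neuron i."""
    t_hats, ds, j_stars, ys = [], [], [], []
    for w, b, U in zip(W, B, U_all):
        t_hat = dot(w, x) + b
        resp = [dot(u, c) for u in U]
        j = max(range(len(U)), key=resp.__getitem__)   # winning segment
        t_hats.append(t_hat)
        ds.append(resp[j])
        j_stars.append(j)
        ys.append(t_hat * sigmoid(resp[j]))
    winners = set(sorted(range(len(ys)), key=ys.__getitem__, reverse=True)[:k])

    grads = []
    for i, U in enumerate(U_all):
        s = sigmoid(ds[i])
        win = i in winners                              # kWTA losers output 0
        g_w = [s * xi if win else 0.0 for xi in x]      # dy/dw = sigma(d) x
        g_b = s if win else 0.0                         # dy/db = sigma(d)
        g_U = [[t_hats[i] * s * (1.0 - s) * cj if (win and j == j_stars[i]) else 0.0
                for cj in c]
               for j in range(len(U))]                  # only u_{j*} gets gradient
        grads.append({"w": g_w, "b": g_b, "U": g_U})
    return winners, j_stars, grads


x, c = [1.0, 0.5], [1.0, 0.0]
W = [[1.0, 0.0], [0.5, 0.5], [-1.0, 0.0]]
B = [0.0, 0.0, 0.0]
U_all = [[[2.0, 0.0], [0.0, 2.0]],
         [[1.0, 0.0], [-1.0, 0.0]],
         [[0.0, 1.0], [3.0, 0.0]]]
winners, j_stars, grads = layer_grads(x, c, W, B, U_all, k=1)
```

Here only neuron 0 wins, and within it only segment 0 receives a gradient; every parameter of neurons 1 and 2, and segment 1 of neuron 0, would be left untouched by this update.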

What do we expect from this model? Different dendritic inputs are expected to activate different subnetworks. If this happens, backpropagation only modifies the connections of the neurons in each subnetwork, leaving the rest of the network's connections untouched (see figure below).

Tests carried out on the permuted MNIST dataset provide empirical evidence that the network does indeed invoke separate subsets of neurons to learn different tasks. As for the results, the authors report that, in the multi-task RL setting, a 3-layer active dendrites network achieves an average accuracy of about 88% when learning 10 Meta-World environment tasks together, while, in the continual learning setting, an almost identical network achieves greater than 90% accuracy when learning 100 permuted MNIST tasks in sequence.
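Permuted MNIST builds many tasks out of one dataset: each task applies its own fixed, random permutation to the pixels of every image, so the labels and pixel statistics stay the same while the spatial layout changes. A minimal sketch, assuming flattened 28×28 images; the helper names are hypothetical, and a placeholder list stands in for a real MNIST digit.

```python
import random


def make_permutations(num_tasks, num_pixels=784, seed=0):
    """One fixed random pixel permutation per task (reproducible via the seed)."""
    rng = random.Random(seed)
    perms = []
    for _ in range(num_tasks):
        p = list(range(num_pixels))
        rng.shuffle(p)
        perms.append(p)
    return perms


def permute_image(flat_image, perm):
    """Apply a task's permutation: output pixel i is input pixel perm[i]."""
    return [flat_image[j] for j in perm]


perms = make_permutations(num_tasks=100)
image = [i / 783 for i in range(784)]   # placeholder in place of a real MNIST digit
task_7_input = permute_image(image, perms[7])
```

In the continual learning setting, the network would see all images under perms[0], then all images under perms[1], and so on, without revisiting earlier tasks.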

Useful links

A. Iyer, K. Grewal, A. Velu, L. O. Souza, J. Forest, S. Ahmad
Avoiding Catastrophe: Active Dendrites Enable Multi-Task Learning in Dynamic Environments
arXiv:2201.00042v1 [cs.NE], 2021.

L. Lapicque’s 1907 paper (English translation, 2007).

Why Neural Networks Forget, and Lessons from the Brain.

J. Snell, K. Swersky, R. S. Zemel
Prototypical networks for few-shot learning
arXiv:1703.05175v2 [cs.LG], 2017.

Permuted MNIST.

T. Yu, D. Quillen, Z. He, R. Julian, A. Narayan, H. Shively, A. Bellathur, K. Hausman, C. Finn, S. Levine
Meta-World: A Benchmark and Evaluation for Multi-Task and Meta Reinforcement Learning
arXiv:1910.10897v2 [cs.LG], 2019 (v2 revised 2021).
