Neural Interpreters

Sparse attention mechanisms and several analogies with programming code


Photo by Shubham Dhage

For this post we refer to the paper “Dynamic Inference with Neural Interpreters” by Rahaman et al. (2021).

Overview

A neural interpreter is a collection of modules that closely resembles a program: it is a set of scripts, which are made up of functions, which are in turn made up of lines of code. Essentially, it is an attention-based network in which inputs are routed through a sequence of functions in a way that is learned end-to-end.

Convolutional networks reuse computational units, such as filters, laterally (once the depth is fixed), whereas recurrent neural networks reuse computational units (RNN cells) only vertically, i.e., in depth. This rigidity in the way networks reuse their units is believed to be one of the reasons for poor generalization. The Neural Interpreter model aims to be an architecture made of independent, composable pieces that relaxes this rigidity in computation reuse.

Input and Output

Assume that the input set contains vector embeddings of image patches or of entire images. These elements are vectors of dimension d_in. The input set additionally includes one or more learned vectors, called CLS tokens, whose corresponding outputs are fed to their respective classifiers. The output is another set of vectors of dimension d_out, with the same cardinality as the input set.
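A minimal sketch of how the input set could be assembled, using plain Python lists as vectors. The helper `append_cls_tokens` and all numbers are hypothetical, and in the real model the CLS tokens are learned parameters rather than fixed values.

```python
from typing import List, Tuple

Vector = List[float]


def append_cls_tokens(patches: List[Vector], cls_tokens: List[Vector]) -> Tuple[List[Vector], List[int]]:
    """Return the input set and the positions of the CLS tokens inside it."""
    d_in = len(patches[0])
    if any(len(v) != d_in for v in patches + cls_tokens):
        raise ValueError("all set elements must have dimension d_in")
    xs = patches + cls_tokens
    return xs, list(range(len(patches), len(xs)))


patches = [[0.1 * k, -0.1 * k, 0.0] for k in range(4)]  # 4 patch embeddings, d_in = 3
cls_tokens = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]          # placeholder values; learned in practice
xs, cls_positions = append_cls_tokens(patches, cls_tokens)
print(len(xs), cls_positions)  # 6 [4, 5]
```

Because the output set has the same cardinality as the input set, the outputs at the CLS positions (here 4 and 5) are the ones that would be passed to the two classifiers.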

Fig. 1. Neural Interpreter

Scripts

A neural interpreter is a stack of nₛ scripts mapping one set of vectors X = {x₁, x₂, …} to another set Y = {y₁, y₂, …} with the same number of elements:

\mathbf{Y} = \mathsf{Neural \; Interpreter}(\mathbf{X}) = \left[ \mathsf{Script}_{n_s} \, \circ \, \cdots \, \circ\, \mathsf{Script}_1 \right](\mathbf{X})
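The composition order can be made concrete with a small sketch. Here each script is just a hypothetical Python callable mapping a list of vectors to a list of the same length; the first script in the list plays the role of Script₁ and runs first, as in the formula above.

```python
from functools import reduce
from typing import Callable, List

Vector = List[float]
SetOfVectors = List[Vector]
Script = Callable[[SetOfVectors], SetOfVectors]


def neural_interpreter(scripts: List[Script]) -> Script:
    """Compose scripts: scripts[0] plays the role of Script_1 and runs first."""
    def run(xs: SetOfVectors) -> SetOfVectors:
        return reduce(lambda acc, script: script(acc), scripts, xs)
    return run


# Hypothetical stand-in scripts (real scripts are type matching + interpreter).
def double(xs: SetOfVectors) -> SetOfVectors:
    return [[2.0 * v for v in x] for x in xs]


def shift(xs: SetOfVectors) -> SetOfVectors:
    return [[v + 1.0 for v in x] for x in xs]


model = neural_interpreter([double, shift])  # n_s = 2
print(model([[1.0, 2.0], [3.0, 4.0]]))  # [[3.0, 5.0], [7.0, 9.0]]
```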

Fig. 2. A neural interpreter is a stack of scripts

Increasing the number of scripts nₛ will increase the depth of the architecture. A script has four components:

  1. a type inference module;
  2. a type matching mechanism;
  3. a set of functions;
  4. an interpreter.

We will soon describe these four components.

Functions

Each script contains functions. Functions are vector-valued instructions to the other components of the script. Formally, a function fᵤ is a pair (sᵤ, cᵤ), where sᵤ is called the signature and cᵤ the code (u is an index). The signature is a unit-norm vector of dimension d_type that tells the type matching mechanism (see below) which inputs should be routed to fᵤ (note the analogy with type signatures in programming). The code cᵤ, a learned parameter of each function, tells the function what to do (more details in a moment). Each function has its own code, which stays the same for every input.
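A function can be represented as a small record holding its signature and its code. The sketch below is a hypothetical data structure, not the authors' code: the Gaussian initialization, the name `d_code` and the sizes are assumptions, and the signature is projected onto the unit hypersphere when the record is built.

```python
import math
import random
from dataclasses import dataclass
from typing import List

Vector = List[float]


def l2_normalize(v: Vector) -> Vector:
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


@dataclass
class Function:
    signature: Vector  # s_u, dimension d_type, projected onto the unit hypersphere
    code: Vector       # c_u, learned code vector, dimension d_code (assumed name)

    def __post_init__(self) -> None:
        self.signature = l2_normalize(self.signature)


def init_functions(n_f: int, d_type: int, d_code: int, seed: int = 0) -> List[Function]:
    # Gaussian initialization is an assumption made for this sketch.
    rng = random.Random(seed)
    return [Function([rng.gauss(0.0, 1.0) for _ in range(d_type)],
                     [rng.gauss(0.0, 1.0) for _ in range(d_code)])
            for _ in range(n_f)]


fns = init_functions(n_f=3, d_type=8, d_code=16)
print(len(fns), len(fns[0].signature), len(fns[0].code))  # 3 8 16
```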

Fig. 3. Functions inside a script

For example, f₁, f₂ and f₃ share their global parameters, but each has its own code. Samples can jump flexibly from one function to another: each sample is routed through the network independently, on a per-sample basis, and the routing itself is entirely learned.

Not every example is routed to every function, so let's see how an example reaches a function's scope.

Type Matching and Inference

Before reaching the functions, the set elements go through a sort of higher-level attention called type matching, which is responsible for routing the set elements through the functions. It is a three-step procedure.

a) First, an input set element xᵢ is processed by an MLP module (called the type inference module) to obtain a type vector tᵢ of dimension d_type. This vector lies on the same unit hypersphere 𝓣 that contains the signature vectors sᵤ.
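The type inference step can be sketched as an MLP followed by L2 normalization, which places tᵢ on the unit hypersphere. The two-layer, bias-free MLP with GELU and the random weights are assumptions made for illustration; the text only states that an MLP is used.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def gelu(z: float) -> float:
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


def type_inference(x: Vector, W1: Matrix, W2: Matrix) -> Vector:
    """Hypothetical two-layer MLP (no biases) followed by projection onto the unit hypersphere."""
    t = matvec(W2, [gelu(h) for h in matvec(W1, x)])
    norm = math.sqrt(sum(v * v for v in t))
    return [v / norm for v in t]


rng = random.Random(0)
d_in, d_hidden, d_type = 6, 12, 4
W1 = [[rng.gauss(0.0, 0.5) for _ in range(d_in)] for _ in range(d_hidden)]
W2 = [[rng.gauss(0.0, 0.5) for _ in range(d_hidden)] for _ in range(d_type)]
t = type_inference([rng.gauss(0.0, 1.0) for _ in range(d_in)], W1, W2)
print(len(t), round(sum(v * v for v in t), 6))  # 4 1.0
```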

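b) Consider a function fᵤ. Define a distance based on the cosine similarity between the type vector tᵢ and the signature sᵤ, namely d𝓣(sᵤ, tᵢ) = 1 − sᵤ · tᵢ (since both vectors have unit norm, their dot product is their cosine similarity).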

c) Next, a sort of softmax with normalization is applied to these distances, returning a coefficient Cᵤᵢ. However, Cᵤᵢ is set to 0 if the distance between sᵤ and tᵢ is larger than τ, a value called the truncation parameter. This introduces sparsity into the model. For fixed u and i, Cᵤᵢ is the compatibility between function fᵤ and set element xᵢ: xᵢ can be processed by fᵤ only if Cᵤᵢ is sufficiently large. If Cᵤᵢ = 0, then fᵤ cannot access xᵢ.
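Putting steps b) and c) together, a sketch of the compatibility computation might look as follows. It assumes that the truncated softmax takes the form exp(−d𝓣/σ), set to zero when the distance exceeds τ, and that it is normalized over the functions with a small ε in the denominator; the temperature σ and ε are assumptions not named in the text above.

```python
import math
from typing import List

Vector = List[float]


def compatibilities(signatures: List[Vector], types: List[Vector],
                    tau: float, sigma: float = 1.0, eps: float = 1e-6) -> List[List[float]]:
    """Return C[u][i] for unit-norm signatures s_u and unit-norm types t_i (sketch)."""
    n_f, n = len(signatures), len(types)
    raw = [[0.0] * n for _ in range(n_f)]
    for u, s in enumerate(signatures):
        for i, t in enumerate(types):
            d = 1.0 - sum(a * b for a, b in zip(s, t))  # d_T(s_u, t_i) = 1 - s_u . t_i
            raw[u][i] = math.exp(-d / sigma) if d <= tau else 0.0  # truncation -> sparsity
    C = [[0.0] * n for _ in range(n_f)]
    for i in range(n):
        z = eps + sum(raw[u][i] for u in range(n_f))  # normalize over functions u
        for u in range(n_f):
            C[u][i] = raw[u][i] / z
    return C


signatures = [[1.0, 0.0], [0.0, 1.0]]          # s_1, s_2
types = [[1.0, 0.0], [0.6, 0.8], [-1.0, 0.0]]  # t_1, t_2, t_3
C = compatibilities(signatures, types, tau=0.5)
```

In this example the first element is routed almost entirely to f₁, the second is shared between f₁ and f₂ (with more weight on f₂, whose signature is closer), and the third is compatible with neither function: its column of C is all zeros, and ε avoids a division by zero.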

Fig. 4. Type matching and inference

Modulated Linear Layers and Modulated MLPs

The following constructs are needed to define an attention mechanism later on. These constructs should be interpreted as programmable modules (the program is determined by the code c). Modulated linear layers act like linear layers with the only difference being that, instead of x, the linear transformation is applied to

\mathbf{x}^\prime = \mathbf{x} \otimes \mathsf{LayerNorm}(\mathbf{W}_c \, \mathbf{c})

where W𝒸 is a learnable matrix that constitutes a set of parameters shared among all functions in the same script (the symbol ⊗ denotes entry-wise product). In short

\mathbf{y} = \mathsf{ModLin}(\mathbf{x}; \mathbf{c} ) = \mathbf{W}\mathbf{x}^\prime + \mathbf{b}

where W is a weight matrix and b is a bias term. Having defined the modulated linear layer, one can also stack L of them (all sharing the same code c), interspersed with GELU activation functions, to obtain the modulated MLP:

\begin{aligned} \mathbf{y} &= \mathsf{ModMLP}(\mathbf{x}; \mathbf{c} )\\ &= ( \mathsf{ModLin}_L(\bullet; \mathbf{c} ) \, \circ \, \mathsf{Activation} \, \circ \, \cdots \, \circ \, \mathsf{Activation} \, \circ \, \mathsf{ModLin}_1(\bullet; \mathbf{c} ) )(\mathbf{x})\,. \end{aligned}
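Both constructs fit in a few lines of plain Python. This is a minimal sketch under stated assumptions: LayerNorm has no learnable gain or bias, each ModLin layer carries its own W, b and W_c, and GELU is the exact erf-based form. The helper names are hypothetical, not taken from the paper.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def layer_norm(v: Vector, eps: float = 1e-5) -> Vector:
    # Simplified LayerNorm: zero mean, unit variance, no learnable gain/bias.
    mean = sum(v) / len(v)
    var = sum((a - mean) ** 2 for a in v) / len(v)
    return [(a - mean) / math.sqrt(var + eps) for a in v]


def gelu(z: float) -> float:
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


def mod_lin(x: Vector, c: Vector, W: Matrix, b: Vector, W_c: Matrix) -> Vector:
    # x' = x (entry-wise) LayerNorm(W_c c), then y = W x' + b
    x_prime = [xi * mi for xi, mi in zip(x, layer_norm(matvec(W_c, c)))]
    return [yi + bi for yi, bi in zip(matvec(W, x_prime), b)]


def mod_mlp(x: Vector, c: Vector, layers: List[Tuple[Matrix, Vector, Matrix]]) -> Vector:
    # ModLin_L o GELU o ... o GELU o ModLin_1, every layer conditioned on the same code c
    h = x
    for k, (W, b, W_c) in enumerate(layers):
        if k > 0:
            h = [gelu(v) for v in h]
        h = mod_lin(h, c, W, b, W_c)
    return h


# The same weights behave differently under two different codes.
identity = [[1.0, 0.0], [0.0, 1.0]]
layer = (identity, [0.0, 0.0], [[1.0], [-1.0]])  # (W, b, W_c), with a 1-dimensional code
print(mod_lin([1.0, 2.0], [1.0], *layer))   # approximately [1.0, -2.0]
print(mod_lin([1.0, 2.0], [-1.0], *layer))  # approximately [-1.0, 2.0]
```

The example illustrates the "programmable module" idea: the weights W, b and W_c are fixed, and only the code c changes what the layer computes.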

ModAttn

A conditional multi-head attention mechanism is used, that is, one conditioned on the code vector cᵤ of function fᵤ. Queries, keys and values are computed with ModLin layers (instead of plain linear layers) for each head h:

\begin{aligned} \mathbf{k}_{uhi} &= \mathsf{ModLin}_{\textsf{key}}^h(\mathbf{x}_i \,; \, \mathbf{c}_u )\\ \mathbf{q}_{uhi} &= \mathsf{ModLin}_{\textsf{query}}^h(\mathbf{x}_i \,; \, \mathbf{c}_u )\\ \mathbf{v}_{uhi} &= \mathsf{ModLin}_{\textsf{value}}^h(\mathbf{x}_i \,; \, \mathbf{c}_u ) \,. \end{aligned}

Then, consider again the compatibility coefficients {Cᵤᵢ}; these quantities serve as modulators when computing the self-attention weights, which are given by the normalized expression

\displaystyle W_{uhij} = \frac{\tilde{W}_{uhij}}{\epsilon + \sum_{j} \tilde{W}_{uhij}}

where ε is a small constant that prevents division by (near-)zero values and

\displaystyle \tilde{W}_{uhij} = C_{ui}C_{uj}\left[ \mathsf{softmax}_j \left(\frac{ \mathbf{q}_{uhi} \cdot \mathbf{k}_{uhj} }{\sqrt{d_\textsf{key}}}\right) \right] \,.

For example, fix fᵤ and the head h. Then we have

\displaystyle \tilde{W}_{ij} = C_{i}C_{j}\left[ \mathsf{softmax}_j \left(\frac{ \mathbf{q}_{i} \cdot \mathbf{k}_{j} }{\sqrt{d_\textsf{key}}}\right) \right]

and, after normalization, the weight Wᵢⱼ is the attention weight between elements xᵢ and xⱼ. Intuitively, information about xᵢ and xⱼ is mixed by fᵤ at head h only if Wᵤₕᵢⱼ is not 0. This requires two conditions to hold at the same time: 1) both compatibility factors Cᵤᵢ and Cᵤⱼ are non-zero (that is, fᵤ can access both xᵢ and xⱼ), and 2) the self-attention weight (the softmax part) is not close to zero. Finally, the following linear combination is computed
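For a fixed function fᵤ and head h, the modulated weights can be sketched directly from the formulas for W̃ and W. Queries and keys are given as plain lists with hypothetical values (in the model they would come from the ModLin projections), and ε = 1e-6 is an assumed value.

```python
import math
from typing import List

Vector = List[float]


def softmax(z: List[float]) -> List[float]:
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def modulated_attention_weights(q: List[Vector], k: List[Vector], C_u: List[float],
                                eps: float = 1e-6) -> List[List[float]]:
    """W[i][j] for one function u and one head h, given q_i, k_j and compatibilities C_u[i]."""
    d_key = len(k[0])
    n = len(q)
    W = []
    for i in range(n):
        logits = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d_key) for j in range(n)]
        p = softmax(logits)                                   # softmax over j
        w_tilde = [C_u[i] * C_u[j] * p[j] for j in range(n)]  # modulation by C_ui * C_uj
        z = eps + sum(w_tilde)                                # renormalize over j
        W.append([w / z for w in w_tilde])
    return W


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
C_u = [0.9, 0.0, 0.5]  # f_u cannot access the second element
W = modulated_attention_weights(q, k, C_u)
```

Since the second compatibility is zero, the second row and the second column of W are exactly zero: fᵤ neither reads from nor writes to that element, regardless of what the softmax says.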

\displaystyle \tilde{\mathbf{y}}_{uhi} = \sum_j W_{uhij}\mathbf{v}_{uhj}

and the final output is

\displaystyle \tilde{\mathbf{y}}_{ui} = \mathsf{ModLin}(\tilde{\mathbf{y}}_{ui;h}\,;\,\mathbf{c}_u)

where the semicolon separating ui from h indicates that the outputs of the various heads are folded into a single vector, i.e., concatenated, as usual in multi-head attention.
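The value aggregation and the head folding can be sketched for one function fᵤ. The final ModLin projection is omitted to keep the example short, the attention weights are hypothetical numbers rather than outputs of the previous step, and both heads reuse the same values for simplicity (in the model each head has its own value projection).

```python
from typing import List

Vector = List[float]


def aggregate_values(W: List[List[float]], v: List[Vector]) -> List[Vector]:
    """y_tilde_uhi = sum_j W_uhij * v_uhj, for one function u and one head h."""
    d = len(v[0])
    return [[sum(W[i][j] * v[j][k] for j in range(len(v))) for k in range(d)]
            for i in range(len(W))]


def fold_heads(per_head: List[List[Vector]]) -> List[Vector]:
    """Concatenate the head outputs of each element i into one vector (the ';h' index)."""
    n = len(per_head[0])
    return [[a for head in per_head for a in head[i]] for i in range(n)]


v = [[1.0, 0.0], [0.0, 2.0]]  # values v_uhj (hypothetical)
head1 = aggregate_values([[0.5, 0.5], [0.0, 1.0]], v)
head2 = aggregate_values([[1.0, 0.0], [0.0, 1.0]], v)
print(fold_heads([head1, head2]))  # [[0.5, 1.0, 1.0, 0.0], [0.0, 2.0, 0.0, 2.0]]
```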

Line of Code

A line of code layer is a ModAttn layer followed by a ModMLP layer (see Fig. 5). Both layers are conditioned on the same code vector, and each of them is wrapped in a weighted residual connection.

Fig. 5. Lines of code

A line of code (LOC) is a line of code layer applied in parallel streams, one per function, as shown in Fig. 5 (right). The inputs of a LOC, say {xᵤᵢ}, carry an extra index u, meaning that xᵤᵢ is the input to function fᵤ. If a function fᵤ cannot access xᵤᵢ, then fᵤ acts on xᵤᵢ as the identity. For example, focus on a particular function fᵤ and on its inputs {xᵤᵢ} as i varies. Then

\mathbf{a}_{ui} = \mathbf{x}_{ui} +C_{ui} \tilde{\mathbf{a}}_{ui}

where ãᵤᵢ is the output of the attention layer (ModAttn); then

\mathbf{y}_{ui} = \mathbf{a}_{ui} +C_{ui} \tilde{\mathbf{b}}_{ui}

where b̃ᵤᵢ is the output of the MLP module (essentially a ModMLP layer). Note that if fᵤ cannot access xᵤᵢ (that is, Cᵤᵢ = 0), then the output yᵤᵢ is just xᵤᵢ.
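The two residual updates of a LOC can be sketched for a single function fᵤ and element i. Here the attention output ãᵤᵢ is passed in precomputed (it depends on the whole set), `toy_mlp` is a hypothetical stand-in for ModMLP, and any normalization the full model may apply before each sub-layer is omitted.

```python
from typing import Callable, List

Vector = List[float]


def loc_update(x: Vector, C: float, a_tilde: Vector,
               mlp: Callable[[Vector], Vector]) -> Vector:
    a = [xi + C * ai for xi, ai in zip(x, a_tilde)]     # a_ui = x_ui + C_ui * a_tilde_ui
    b_tilde = mlp(a)                                     # stand-in for the ModMLP output
    return [ai + C * bi for ai, bi in zip(a, b_tilde)]  # y_ui = a_ui + C_ui * b_tilde_ui


def toy_mlp(v: Vector) -> Vector:
    # Hypothetical MLP, used only to make the example concrete.
    return [2.0 * t for t in v]


x, a_tilde = [1.0, -1.0], [0.5, 0.5]
print(loc_update(x, 1.0, a_tilde, toy_mlp))  # [4.5, -1.5]
print(loc_update(x, 0.0, a_tilde, toy_mlp))  # [1.0, -1.0], the identity when C_ui = 0
```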

Interpreter

The interpreter layer is a stack of LOCs sharing the same function codes. The interpreter broadcasts a given set element to multiple parallel computational streams, one for each function. Let the number of stacked LOCs be nₗ. Let X = {x₁, x₂, …} and, for each function fᵤ, let Cᵤ = {Cᵤ₁, Cᵤ₂, …}; then

\mathbf{y}_{i} = \mathbf{x}_{i} + C_{1i}\, \big[\mathcal{L}(\mathbf{X}, \mathbf{c}_1, \mathbf{C}_1)\big]_i + C_{2i}\, \big[\mathcal{L}(\mathbf{X}, \mathbf{c}_2, \mathbf{C}_2)\big]_i + \cdots

where

\mathcal{L} = \underbrace{\mathsf{LOC}_{n_l}\,\circ\, \cdots\,\circ\,\mathsf{LOC}_1}_{n_l \, \textsf{times}} \,.

Essentially, the output is the input plus a weighted sum of the per-function streams, where the coefficients are the compatibilities Cᵤᵢ of element xᵢ with each function fᵤ. Given a set of inputs and an instruction (the function code), the role of the interpreter is to execute that instruction and compute the output.
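The combination step can be sketched for a single element xᵢ. The per-function stream outputs (the i-th entries produced by 𝓛 for each function) are taken as given, with hypothetical values; only the weighting by Cᵤᵢ and the residual term are shown.

```python
from typing import List

Vector = List[float]


def combine_streams(x_i: Vector, C_i: List[float], streams_i: List[Vector]) -> Vector:
    """y_i = x_i + sum_u C_ui * [L(X, c_u, C_u)]_i, with C_i[u] = C_ui and streams_i[u] = [L(...)]_i."""
    y = list(x_i)
    for C_ui, s in zip(C_i, streams_i):
        y = [yk + C_ui * sk for yk, sk in zip(y, s)]
    return y


x_i = [1.0, 1.0]
C_i = [0.75, 0.25, 0.0]                              # the third function cannot access x_i
streams_i = [[2.0, 0.0], [0.0, 4.0], [100.0, 100.0]]
print(combine_streams(x_i, C_i, streams_i))          # [2.5, 2.0]
```

The third stream has a large output but contributes nothing, because its compatibility with xᵢ is zero.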

Increasing the number of LOCs nₗ increases both the depth of the architecture and the number of parameters.

Function Iteration

We have already seen that the overall model is a stack of multiple scripts. Each script, in turn, can be expressed as the repeated application of Function Iteration (FnIter):

\{ \mathbf{y}_1, \mathbf{y}_2, \dots\} =( \underbrace{\mathsf{FnIter}\,\circ\, \cdots\,\circ\,\mathsf{FnIter}}_{n_i \, \textsf{times}})( \{ \mathbf{x}_1, \mathbf{x}_2, \dots\})

where FnIter is defined as the composition of the type matching mechanism and the interpreter.

The number of function iterations nᵢ can be increased without increasing the number of parameters, so FnIter enables parameter sharing in depth.
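The recurrence can be sketched with a single shared parameter set. The hypothetical `fn_iter` step stands in for type matching followed by the interpreter and is applied nᵢ times, so the depth grows while the parameter count stays fixed.

```python
from typing import Callable, Dict, List

Vector = List[float]
SetOfVectors = List[Vector]


def make_script(fn_iter: Callable[[SetOfVectors], SetOfVectors],
                n_i: int) -> Callable[[SetOfVectors], SetOfVectors]:
    """Apply the same FnIter (same parameters) n_i times."""
    def run(xs: SetOfVectors) -> SetOfVectors:
        for _ in range(n_i):
            xs = fn_iter(xs)
        return xs
    return run


params: Dict[str, float] = {"w": 0.5}  # hypothetical: the only parameter, shared by every iteration


def fn_iter(xs: SetOfVectors) -> SetOfVectors:
    return [[params["w"] * v + 1.0 for v in x] for x in xs]


shallow, deep = make_script(fn_iter, 1), make_script(fn_iter, 3)
print(shallow([[0.0]]), deep([[0.0]]), len(params))  # [[1.0]] [[1.75]] 1
```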

Experiments

Experiments were conducted on tasks such as learning fuzzy Boolean expressions, multi-task image classification and abstract reasoning. We do not delve further into them here: only time will tell whether this recent architecture proves worthwhile.

Useful links

Original article on Neural Interpreters.

Nice discussion with the authors (video).
