A dubious successor to Transformer
AI models have come a long way. Yet, at the level of elementary architecture components, we are stuck in 2017, i.e. the advent of the Transformer. There are a lot of papers and a lot of brilliant ideas, but none have really taken hold… AI prefers simple constructs, such as Gramian matrices to express affinities between objects. It’s crazy, but that’s how it is.
Here comes Retentive Network (RetNet), presented as the successor of the Transformer architecture. Revolution or yet another rip-off?
The impossible Triangle
In the paper “Retentive Network: A Successor to Transformer for Large Language Models” by Y. Sun et al., the authors introduce RetNet as a “foundation architecture for large language models, simultaneously achieving training parallelism, low-cost inference, and good performance”. How?
Transformers have high inference cost. Linear Transformers replace the standard attention scores with linear kernels, obtaining low-cost inference, but their modeling capability and performance are worse than Transformers. Recurrent Neural Networks (RNNs) are just the past… unless they come back with some twist that makes them interesting.
The authors claim that RetNet achieves low-cost inference, efficient long-sequence modeling, Transformer-comparable performance, and parallel model training simultaneously. The main expedients allowing this are the following:
(a) a multi-scale retention mechanism to substitute multi-head attention and
(b) three computation paradigms for this mechanism, i.e., parallel, recurrent, and chunkwise recurrent representations.
According to the authors, the parallel representation empowers training parallelism to utilize GPU devices fully, the recurrent representation enables efficient O(1) inference in terms of memory and computation and the chunkwise recurrent representation can perform efficient long-sequence modeling.
Each local block is encoded with the parallel mechanism for computation speed; in the meantime, the global blocks are recurrently encoded to save GPU memory.
The crux of the retention mechanism is avoiding the Transformer’s softmax operation, which cannot be computed recurrently without simplifications: it requires a quadratic number of attention values right away.
Overall architecture
The input sequence x = x_1 … x_|x| is encoded autoregressively into a vector sequence and then packed into X = [x_1; …; x_|x|], which you can view as a |x| × d matrix, i.e. sequence length × embedding dimension (think of each x_n as a row).
The RetNet architecture consists of L stacked blocks. Each block is composed of two main components: a Multi-Scale Retention module (MSR) and a Feed-Forward Network (FFN). Before each component, Layer Norm (LN) is applied. Each block has two skip connections; the architecture is depicted below.
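To make the block structure concrete, here is a minimal PyTorch sketch of one block (module names, the GELU activation and the FFN width are illustrative choices of this article, not taken from the official implementation; the retention module is plugged in and described in the following sections):

import torch.nn as nn

class RetNetBlock(nn.Module):
    # One RetNet block: LN -> MSR -> skip connection, then LN -> FFN -> skip connection.
    def __init__(self, d_model, ffn_dim, retention_module):
        super().__init__()
        self.msr = retention_module              # multi-scale retention (MSR), described below
        self.ffn = nn.Sequential(                # feed-forward network (FFN)
            nn.Linear(d_model, ffn_dim),
            nn.GELU(),
            nn.Linear(ffn_dim, d_model),
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.msr(self.ln1(x))            # first skip connection
        x = x + self.ffn(self.ln2(x))            # second skip connection
        return x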
Retention
To reveal the details of MSR, it is necessary to introduce the retention mechanism. Given the input X as above (think of it as a |x| × d matrix), we calculate elementary values

v_n = X_n · w_V

and we try to establish a mapping between v_n and an output o_n through states s_n, via a recursive procedure:

s_n = A s_(n−1) + K_n^⊺ v_n
o_n = Q_n s_n = Σ_(m=1..n) Q_n A^(n−m) K_m^⊺ v_m
In the first equation v_n is mapped to the state vector s_n, and then a linear transform is applied to encode sequence information recurrently. Next, we make the projections Q_n, K_n content-aware through the use of weight matrices for queries and keys:

Q = X W_Q,  K = X W_K
where W_Q and W_K are learnable d × d matrices. Then we diagonalize the matrix A (presumably the authors implicitly assume that A is diagonalizable):

A = Λ (γ exp(iϴ)) Λ⁻¹
where γ and ϴ are two d-vectors. The geometric meaning of the complex exponential function will allow us to add positional information later. Indeed, the transformation provides a rotation of vectors by the angle ϴ.
Matrix diagonalization is useful when one wants to evaluate a matrix power; in fact

A^(n−m) = (Λ (γ exp(iϴ)) Λ⁻¹) (Λ (γ exp(iϴ)) Λ⁻¹) ⋯ (Λ (γ exp(iϴ)) Λ⁻¹) = Λ (γ exp(iϴ))^(n−m) Λ⁻¹
because all the inner Λ⁻¹Λ products in the expansion collapse to the identity matrix. The matrix Λ and its inverse can be absorbed into W_Q and W_K respectively, so we can rewrite the equation for o_n as follows:

o_n = Σ_(m=1..n) Q_n (γ exp(iϴ))^(n−m) K_m^⊺ v_m = Σ_(m=1..n) [Q_n (γ exp(iϴ))^n] [K_m (γ exp(iϴ))^(−m)]^⊺ v_m
where the highlighted quantities between brackets represent a relative position embedding proposed for the Transformer, known as xPos, a transformation that provides position encoding by rotating vectors. The above equation can be further simplified by replacing the vector γ (a d-component vector) with a single scalar value, indicated with the same symbol γ:

o_n = Σ_(m=1..n) γ^(n−m) (Q_n exp(inϴ)) (K_m exp(imϴ))† v_m
where † represents the conjugate transpose. This formulation is easily parallelizable within training instances.
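As a sanity check of this derivation, here is a small toy sketch (scalar γ, the complex rotation exp(inϴ) omitted for brevity; all names are illustrative) showing that the direct summation and the recurrent state update produce the same outputs:

import torch

torch.manual_seed(0)
T, d = 5, 4                          # toy sequence length and dimension
gamma = 0.9                          # scalar decay
q = torch.randn(T, d)                # queries Q_n (rows)
k = torch.randn(T, d)                # keys K_n (rows)
v = torch.randn(T, 1)                # one-dimensional values v_n

# direct summation: o_n = sum over m <= n of gamma^(n-m) * (q_n . k_m) * v_m
o_sum = torch.zeros(T, 1)
for n in range(T):
    for m in range(n + 1):
        o_sum[n] += gamma ** (n - m) * (q[n] @ k[m]) * v[m]

# recurrent form: s_n = gamma * s_(n-1) + k_n^T v_n ; o_n = q_n s_n
s = torch.zeros(d, 1)
o_rec = torch.zeros(T, 1)
for n in range(T):
    s = gamma * s + k[n].unsqueeze(-1) * v[n]
    o_rec[n] = q[n] @ s

print(torch.allclose(o_sum, o_rec, atol=1e-5))   # prints True: the two forms agree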
Parallel representation
The parallel representation of retention is used for training. Look at the following architecture.
We see Q = (XW_Q) ⊙ ϴ, K = (XW_K) ⊙ ϴ̄ and V = XW_V, where ϴ represents the “rotational” contribution exp(i·n·ϴ) and ϴ̄ is its complex conjugate. GN is short for GroupNorm. The retention of X is given by the expression

Retention(X) = (Q K^⊺ ⊙ D) V

with the matrix D combining a causal mask (zero entries above the diagonal) and a decay along the relative distance (entries γ^(n−m) for n ≥ m). This matrix plays the role that softmax played in the Transformer; in particular, vectors are weighted with an exponentially decaying factor, so that past tokens matter less for the current time step.
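A minimal single-head sketch of this parallel form (scalar γ, xPos rotation and GroupNorm omitted; the function name and signature are illustrative):

import torch

def parallel_retention(q, k, v, gamma):
    # q, k, v: (T, d) projections of the packed input X; gamma: scalar decay in (0, 1).
    # D combines the causal mask (zeros above the diagonal) with the decay gamma^(n - m).
    T = q.shape[0]
    idx = torch.arange(T)
    dist = idx[:, None] - idx[None, :]                  # n - m
    decay = gamma ** dist.clamp(min=0).float()
    D = torch.where(dist >= 0, decay, torch.zeros_like(decay))
    return (q @ k.transpose(-1, -2) * D) @ v            # (Q K^T ⊙ D) V

Like attention, this materializes the full |x| × |x| matrix, so training keeps the quadratic cost but is fully parallel over the sequence.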
Recurrent representation
The retention mechanism can also be written as a recurrent neural network, which is favorable for inference. The recurrent representation of retention merges aspects of RNNs with those of the Transformer architecture. Q, K, V and γ are the same as in the previous descriptions. Let n denote the n-th timestep (n = 1, …, |x|). Then

S_n = γ S_(n−1) + K_n^⊺ V_n
Retention(X_n) = Q_n S_n
The mechanism is illustrated below.
The “times” product is the outer product of the transpose of K_n with V_n (a matrix), and the “asterisk” product denotes the component-wise product of the state matrix S_n with the transposed Q_n, summed along each column (not all implementations include this sum inside the recurrent retention function, though). This is best seen from the pseudocode in the original article, reproduced below.
def RecurrentRetention(
    q, k, v,   # bsz ∗ num_head ∗ len ∗ qkv_dim
    past_kv,   # bsz ∗ num_head ∗ qk_dim ∗ v_dim
    decay      # num_head ∗ 1 ∗ 1
):
    current_kv = decay * past_kv + k.unsqueeze(-1) * v.unsqueeze(-2)
    output = torch.sum(q.unsqueeze(-1) * current_kv, dim=-2)
    output = group_norm(output)
    return output, current_kv
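Purely as an illustration of how the recurrence is driven at decoding time (a toy sketch with random per-token projections; the shapes follow the comments in the pseudocode, and everything here is illustrative rather than official code):

import torch

# group_norm in the pseudocode above is a placeholder; an identity stands in for it here.
group_norm = lambda t: t
bsz, num_head, qk_dim, v_dim, seq_len = 1, 2, 8, 8, 16
past_kv = torch.zeros(bsz, num_head, qk_dim, v_dim)     # recurrent state, starts at zero
decay = torch.full((num_head, 1, 1), 0.9)               # one decay value per head
for n in range(seq_len):
    q_n = torch.randn(bsz, num_head, qk_dim)            # stand-ins for the projected current token
    k_n = torch.randn(bsz, num_head, qk_dim)
    v_n = torch.randn(bsz, num_head, v_dim)
    out, past_kv = RecurrentRetention(q_n, k_n, v_n, past_kv, decay)

Only past_kv (a qk_dim × v_dim matrix per head) is carried between steps, which is the O(1) inference cost claimed earlier.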
Chunkwise representation
There is even a hybrid representation mixing the parallel and recurrent mechanisms. This representation is intended to speed up training, especially for long sequences. Divide the input sequence into chunks of length B
0 : B
B : 2B
2B : 3B
…
i·B : (i+1)B
…
and express the queries, keys and values chunkwise as follows

Q_[i] = Q_(Bi : B(i+1)),  K_[i] = K_(Bi : B(i+1)),  V_[i] = V_(Bi : B(i+1))
with the adjusted recurrent relation

R_i = K_[i]^⊺ (V_[i] ⊙ ζ) + γ^B R_(i−1)
where [i] indicates the i-th chunk and ζ is an intra-chunk decay factor (ζ_m = γ^(B−m−1) for within-chunk position m). The chunkwise retention

Retention(X_[i]) = (Q_[i] K_[i]^⊺ ⊙ D) V_[i] + (Q_[i] ⊙ ξ) R_(i−1),  with ξ_m = γ^(m+1)

is composed of two pieces: the first is the inner-chunk term, which follows the parallel representation, and the second is the cross-chunk term, which follows the recurrent representation.
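Here is a minimal single-head sketch of the chunkwise scheme (scalar γ, rotation and GroupNorm omitted, sequence length assumed to be a multiple of B; all names are illustrative): the inner-chunk term reuses the decay mask D of the parallel form, while a state R carries the contribution of all previous chunks.

import torch

def chunkwise_retention(q, k, v, gamma, B):
    # q, k, v: (T, d) with T a multiple of B. Inner-chunk: parallel form with decay mask D;
    # cross-chunk: a recurrent state R summarizes all previous chunks.
    T, d = q.shape
    assert T % B == 0
    m = torch.arange(B)
    dist = m[:, None] - m[None, :]
    decay = gamma ** dist.clamp(min=0).float()
    D = torch.where(dist >= 0, decay, torch.zeros_like(decay))   # intra-chunk decay + causal mask
    xi = gamma ** (m[:, None] + 1).float()                       # cross-chunk decay for the queries
    zeta = gamma ** (B - m[:, None] - 1).float()                 # intra-chunk decay for the state update
    R = torch.zeros(d, v.shape[-1])                              # cross-chunk state
    out = []
    for s in range(0, T, B):
        qi, ki, vi = q[s:s+B], k[s:s+B], v[s:s+B]
        inner = (qi @ ki.transpose(-1, -2) * D) @ vi             # inner-chunk (parallel) term
        cross = (qi * xi) @ R                                    # cross-chunk (recurrent) term
        out.append(inner + cross)
        R = (gamma ** B) * R + ki.transpose(-1, -2) @ (vi * zeta)   # state for the next chunk
    return torch.cat(out, dim=0)

Only B × B blocks are ever materialized, and on a toy input the result should coincide with the fully parallel form above.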
Last piece: Gated Multi-Scale Retention
It’s time to define the multi-scale retention (MSR) layer seen at the beginning. Let r be the head dimension. The authors set the number h of retention heads to d/r. MSR assigns a different γ to each head; the γ values are kept identical across layers for simplicity. A swish gate is used to increase the non-linearity of the retention layer. So, given input X, we define the layer as:

γ = 1 − 2^(−5 − arange(0, h))    (one scalar per head)
head_i = Retention(X, γ_i)
Y = GroupNorm_h(Concat(head_1, …, head_h))
MSR(X) = (swish(X W_G) ⊙ Y) W_O
where W_G and W_O are learnable d × d matrices. The heads thus use multiple “discount factor” scales: as the head index runs over 0 … h−1, each head gets its own γ and decays past information at a different rate. The retention layer uses Sub-LayerNorm to normalize outputs; since the multi-scale modeling leads to different variance statistics across the heads, LayerNorm is replaced with GroupNorm.
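Putting the pieces together, a rough sketch of the gated MSR layer (parallel form, single sequence, reusing the parallel_retention sketch from above; module and variable names are illustrative and not those of the official code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiScaleRetention(nn.Module):
    # Sketch of gated multi-scale retention: per-head decay gamma, GroupNorm over heads,
    # swish gate and output projection. Relies on parallel_retention defined earlier.
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.h = n_heads
        self.r = d_model // n_heads                              # head dimension r
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_G = nn.Linear(d_model, d_model, bias=False)       # gate projection
        self.W_O = nn.Linear(d_model, d_model, bias=False)       # output projection
        self.gn = nn.GroupNorm(n_heads, d_model)
        # per-head decay schedule from the paper: gamma_i = 1 - 2^(-5 - i), i = 0 ... h-1
        self.register_buffer("gammas", 1 - 2.0 ** (-5 - torch.arange(n_heads).float()))

    def forward(self, x):                                        # x: (T, d_model)
        T = x.shape[0]
        q = self.W_Q(x).view(T, self.h, self.r)
        k = self.W_K(x).view(T, self.h, self.r)
        v = self.W_V(x).view(T, self.h, self.r)
        heads = [parallel_retention(q[:, i], k[:, i], v[:, i], self.gammas[i].item())
                 for i in range(self.h)]                         # one retention per head
        y = self.gn(torch.cat(heads, dim=-1))                    # GroupNorm over concatenated heads
        return self.W_O(F.silu(self.W_G(x)) * y)                 # swish gate, then output projection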
Conclusions
What about performance? Well, this is another paper full of (supposedly) excellent results. Time will tell whether it is a real revolution or just another clumsy attempt to outdo the Transformer architecture. Everything seems so linear and recurrent, perhaps too much so to deliver the expected results?
Useful links
Retentive Network: A Successor to Transformer for Large Language Models
Y. Sun, L. Dong, S. Huang, S. Ma, Y. Xia, J. Xue, J. Wang, F. Wei
arXiv:2307.08621 [cs.CL] (2023)
Official implementation on GitHub (link)
PyTorch implementation of RetNet by Jamie Stirling (link)
xPos paper (link)