CoAtNets

A class of state-of-the-art computer vision models

Photo by Monisha Selvakumar

This post refers mainly to the paper CoAtNet: Marrying Convolution and Attention for All Data Sizes by Z. Dai et al. (2021).

CoAtNet models (pronounced “coat” net) are computer vision models that combine the convolutional and Transformer (self-attention based) architectures. Experiments show that CoAtNets achieve state-of-the-art performance across datasets of different sizes, such as ImageNet and JFT-3B.

Convolution and Self-Attention

Convolutional neural networks (CNNs) are built on the convolution operation (check here for a simple introduction). Let x be a given input, such as an image or, more generally, a feature representation, with dimensions r × c × d, where r and c are the rows and columns of the image (or representation) and d is the number of channels. Let \mathcal{L}(i) be a local image patch around pixel xᵢ, where i denotes the coordinates (α, β). Then the convolution output yᵢ is

\displaystyle y_i = \sum_{j \in \mathcal{L}(i)} w_{i-j} \odot x_j

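where i − j = (α, β) − (m, n) = (α − m, β − n) is the offset between pixels i and j, and w_{i-j} is a weight (a convolution kernel entry) that depends only on that offset. The index j = (m, n) varies over the patch \mathcal{L}(i). Note that xᵢ can also be considered as a 1 × 1 × d vector, so w_{i-j} is d-dimensional and ⊙ is the element-wise, channel-by-channel product (a depthwise convolution, the form used in the CoAtNet paper). The figure below shows an example with i = (3, 3), that is, the local patch \mathcal{L}(3, 3) of representation x.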
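To make the indexing concrete, here is a minimal NumPy sketch of the formula above, assuming a per-channel (depthwise) product, an odd kernel size k and zero padding at the borders. The function name `depthwise_conv2d` and the kernel layout `w[o_row + h, o_col + h]` are illustrative choices, not taken from the paper.

```python
import numpy as np

def depthwise_conv2d(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """y_i = sum_{j in L(i)} w_{i-j} * x_j, with * the per-channel product.

    x: input of shape (r, c, d); w: kernel of shape (k, k, d), k odd.
    The offset i - j ranges over {-h, ..., h}^2 with h = k // 2, and
    w[o_row + h, o_col + h] stores w_{(o_row, o_col)}.
    Pixels outside x are treated as zero (zero padding).
    """
    r, c, d = x.shape
    h = w.shape[0] // 2
    y = np.zeros_like(x, dtype=float)
    for a in range(r):
        for b in range(c):
            # j = (m, n) runs over the k x k patch L(i) centred at i = (a, b)
            for m in range(a - h, a + h + 1):
                for n in range(b - h, b + h + 1):
                    if 0 <= m < r and 0 <= n < c:
                        y[a, b] += w[a - m + h, b - n + h] * x[m, n]
    return y

x = np.zeros((5, 5, 2))
x[2, 2] = 1.0  # unit impulse at pixel (2, 2) in both channels
w = np.arange(18, dtype=float).reshape(3, 3, 2)
y = depthwise_conv2d(x, w)
print(np.allclose(y[1:4, 1:4], w))  # True
```

Feeding in a single impulse reproduces the kernel around it, a quick check that the offset i − j is indexed the right way round. Deep learning libraries usually implement cross-correlation (w_{j-i} instead of w_{i-j}), which differs only by a kernel flip.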

CNNs employ weight sharing: the same kernel is reused to compute the output at every pixel position (α, β), since w_{i-j} depends only on the offset i − j. Weight sharing enforces translation equivariance:

convolve(translate(x)) = translate(convolve(x))

This is a useful property: if a CNN detects a particular element in an image, it will detect the same element again, correspondingly shifted, when the image is translated.
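The equivariance identity can be checked numerically. The sketch below assumes periodic (wrap-around) boundaries, so that `np.roll` serves both as the translation and as the shift inside the convolution and the identity holds exactly; with zero padding it holds only away from the borders. `circular_conv2d` and `translate` are hypothetical helper names.

```python
import numpy as np

def circular_conv2d(x, w):
    """Per-channel convolution with periodic boundaries:
    y_i = sum_o w_o * x_{i - o}, over offsets o in {-h, ..., h}^2."""
    h = w.shape[0] // 2
    y = np.zeros_like(x, dtype=float)
    for o_r in range(-h, h + 1):
        for o_c in range(-h, h + 1):
            # np.roll(x, o)[i] == x[i - o], with indices wrapping around
            y += w[o_r + h, o_c + h] * np.roll(x, shift=(o_r, o_c), axis=(0, 1))
    return y

def translate(x, shift):
    """Shift an (r, c, d) map by `shift` pixels, wrapping around the edges."""
    return np.roll(x, shift=shift, axis=(0, 1))

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8, 3))
w = rng.normal(size=(3, 3, 3))
lhs = circular_conv2d(translate(x, (2, -1)), w)  # convolve(translate(x))
rhs = translate(circular_conv2d(x, w), (2, -1))  # translate(convolve(x))
print(np.allclose(lhs, rhs))  # True
```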

For self-attention, consider a 1 × 1 × d “pixel” xᵢ and a region \mathcal{G} centered, for simplicity, at xᵢ. This is similar to a local image patch, but the letter \mathcal{G} indicates that the region can even be global, that is, span the whole image. Single-headed attention yᵢ is

\displaystyle y_i = \sum_{j \in \mathcal{G}} \textsf{softmax}_\mathcal{G} \left( q_i^\top k_j \right) v_j

where the queries qᵢ = Qxᵢ, keys kⱼ = Kxⱼ and values vⱼ = Vxⱼ are described here, and Q, K and V are learned matrices. The softmax normalizes the logits q_i^\top k_j computed from all pixels in the neighborhood \mathcal{G} of xᵢ, and the notation j \in \mathcal{G} indicates that the sum runs over all indices j corresponding to elements (pixels) in \mathcal{G}. This computation is repeated for every pixel xᵢ to obtain the outputs yᵢ. In practice, multiple attention heads are used to learn distinct representations of the input. The figure below illustrates this computation.

The dashed lines represent learned transformations; the rest are matrix operations.
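Here is a minimal NumPy sketch of single-headed attention over a global region \mathcal{G}, with the pixels of an image flattened into the rows of a matrix. It uses the row-vector convention, so `x @ Q` in the code plays the role of Qxᵢ in the text; the function names and shapes are our own assumptions.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, Q, K, V):
    """Single-head attention where G is the whole set of pixels.
    x: (N, d) flattened pixels; Q, K: (d, d_k); V: (d, d_v)."""
    q, k, v = x @ Q, x @ K, x @ V   # queries, keys and values for every pixel
    logits = q @ k.T                # logits[i, j] = q_i^T k_j
    a = softmax(logits, axis=-1)    # softmax over j in G, separately for each i
    return a @ v, a                 # y_i = sum_j a[i, j] v_j

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 4, 8)).reshape(16, 8)  # a 4 x 4 image with d = 8, flattened
Q, K, V = (rng.normal(size=(8, 8)) for _ in range(3))
y, a = self_attention(x, Q, K, V)
print(y.shape)  # (16, 8)
```

Setting K to zero makes every logit equal, so each yᵢ becomes the plain average of the values; it is the learned Q and K that make the weights depend on the input.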

In the current setting, no positional information is encoded in attention: permuting the other pixels in \mathcal{G} leaves yᵢ unchanged. This limits the expressiveness of vision models. Position information can be added through the well-known absolute positional embeddings based on sinusoidal functions; however, many experiments suggest that relative positional embeddings give better results.

Relative attention is defined as follows. For the pixel i = (α, β), consider its relative distance to each position j = (m, n) in \mathcal{G}, so that each position determines two distances: a row offset m − α and a column offset n − β (see figure below). In the figure, the relative distances are computed with respect to pixel (0, 0), as an example, and each entry shows the row offset (yellow) and the column offset (gray).

The row and column offsets are associated with the embeddings r_{m-\alpha} and r_{n-\beta}, respectively, each of dimension d_{out}/2, half the output dimension. Concatenating these two vectors gives a single embedding r_{j-i} of dimension d_{out}, and the relative attention reads

\displaystyle y_i = \sum_{j \in \mathcal{G}} \textsf{softmax}_\mathcal{G} \left( q_i^\top k_j + q_i^\top r_{j-i} \right)v_j\,.

The argument of the softmax therefore has two components: the logit q_i^\top k_j, expressing the similarity between the query and an element of \mathcal{G}, and the term q_i^\top r_{j-i}, which depends only on the relative distance of that element from the query. Because the positional term depends only on j − i, self-attention with relative positions also enjoys translation equivariance, like convolution.
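The relative version only adds the term q_i^\top r_{j-i} to each logit. In the sketch below, an illustration under our own assumptions rather than the reference implementation of either paper, `R_row` and `R_col` are hypothetical lookup tables holding the row and column embeddings indexed by offset; for every pair (i, j) the two looked-up halves are concatenated into r_{j-i}.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def relative_attention(x, Q, K, V, R_row, R_col):
    """x: (H, W, d); Q, K: (d, d_out); V: (d, d_v).
    R_row: (2H-1, d_out//2) holds r_{m-alpha} at index (m - alpha) + H - 1.
    R_col: (2W-1, d_out//2) holds r_{n-beta} at index (n - beta) + W - 1."""
    H, W, d = x.shape
    coords = np.array([(a, b) for a in range(H) for b in range(W)])  # row-major, (N, 2)
    X = x.reshape(H * W, d)
    q, k, v = X @ Q, X @ K, X @ V
    off = coords[None, :, :] - coords[:, None, :]  # off[i, j] = j - i, shape (N, N, 2)
    r = np.concatenate([R_row[off[..., 0] + H - 1],   # row-offset half
                        R_col[off[..., 1] + W - 1]],  # column-offset half
                       axis=-1)                       # r_{j-i}, shape (N, N, d_out)
    logits = q @ k.T + np.einsum("id,ijd->ij", q, r)  # q_i^T k_j + q_i^T r_{j-i}
    return (softmax(logits) @ v).reshape(H, W, -1)

rng = np.random.default_rng(0)
H, W, d, d_out = 3, 4, 6, 8
x = rng.normal(size=(H, W, d))
Q, K, V = (rng.normal(size=(d, d_out)) for _ in range(3))
R_row = rng.normal(size=(2 * H - 1, d_out // 2))
R_col = rng.normal(size=(2 * W - 1, d_out // 2))
y = relative_attention(x, Q, K, V, R_row, R_col)
print(y.shape)  # (3, 4, 8)
```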

Merging desirable properties

It is worthwhile to compare the relative strengths and weaknesses of convolution and self-attention before asking how best to combine them.

Translation Equivariance. We saw earlier that this is a property satisfied by convolution.

Input-adaptive Weighting. In convolution, the kernel entries are static and do not depend on the particular input. By contrast, the attention weights (the softmax terms) depend dynamically on the representation of the input.

Global Receptive Field. One of the most crucial differences between self-attention and convolution concerns the size of the receptive field: convolution only sees the local patch \mathcal{L}(i), whereas self-attention can cover the global region \mathcal{G}. A larger receptive field provides more contextual information, which could lead to higher model capacity, at the price of a higher computational cost.

An ideal model would combine all three properties. With this in mind, the authors add a global static convolution kernel to the adaptive attention logits before the softmax normalization, obtaining the following attention mechanism for their model:

\displaystyle y_i = \sum_{j \in \mathcal{G}} \frac{\exp(x_i^\top x_j + w_{i-j})}{\sum_{k \in \mathcal{G}} \exp(x_i^\top x_k + w_{i-k})}\;x_j

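which is a form of relative attention, compactly written as

\displaystyle y_i = \sum_{j \in \mathcal{G}} \textsf{softmax}_\mathcal{G} \left( x_i^\top x_j + w_{i-j} \right)x_j

where the static weights w_{i-j} take the place of the relative-position term. Here \mathcal{G} indicates the global spatial space and, for each j, the weight w_{i-j} is a scalar. For a fixed i there is one such weight for each position j in \mathcal{G}, but since w_{i-j} depends only on the offset i − j, an r × c grid needs only (2r − 1)(2c − 1) distinct weights, shared across all positions just like a convolution kernel.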
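A NumPy sketch of the formula above, assuming a single head and no query, key or value projections, exactly as the equation is written. The scalar weights are stored in a hypothetical table `w_table` of shape (2r − 1) × (2c − 1), indexed by the offset i − j; a practical layer would typically add learned projections and multiple heads.

```python
import numpy as np

def coatnet_relative_attention(x, w_table):
    """y_i = sum_j softmax_j(x_i^T x_j + w_{i-j}) x_j over the whole H x W grid.
    x: (H, W, d); w_table: (2H-1, 2W-1) scalars, where
    w_table[o_r + H - 1, o_c + W - 1] stores w_{(o_r, o_c)}."""
    H, W, d = x.shape
    X = x.reshape(H * W, d)
    coords = np.array([(a, b) for a in range(H) for b in range(W)])
    off = coords[:, None, :] - coords[None, :, :]  # off[i, j] = i - j, shape (N, N, 2)
    bias = w_table[off[..., 0] + H - 1, off[..., 1] + W - 1]  # w_{i-j}, shape (N, N)
    logits = X @ X.T + bias                        # added before the softmax
    logits -= logits.max(axis=1, keepdims=True)    # numerical stability
    a = np.exp(logits)
    a /= a.sum(axis=1, keepdims=True)              # softmax over j in G
    return (a @ X).reshape(H, W, d)

rng = np.random.default_rng(0)
H, W, d = 3, 3, 4
x = rng.normal(size=(H, W, d))
w_table = rng.normal(size=(2 * H - 1, 2 * W - 1))
y = coatnet_relative_attention(x, w_table)
print(y.shape)  # (3, 3, 4)
```

Because the softmax is unchanged when the same constant is added to all logits, shifting every entry of `w_table` leaves the output unchanged; a very large weight at offset (0, 0), on the other hand, makes each pixel attend almost only to itself.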

CoAtNet model

In the case of global attention, the complexity is quadratic in the spatial size: with N pixels, every pixel attends to every other pixel, which gives N² logits. So it is not always feasible to use self-attention in vision tasks: applying the attention defined above directly to raw images would be excessively slow because of the (usually) large number of pixels involved. Hence, the authors consider three main options:

(A) perform some down-sampling to reduce the spatial size, and employ the global relative attention once the feature map reaches a manageable size;

(B) enforce local attention, which restricts the global receptive field \mathcal{G} in attention to a local field \mathcal{L}, just like in convolution;

(C) replace the quadratic softmax attention with a linear attention variant, whose complexity is only linear in the spatial size.
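Counting attention logits per head and per layer makes the quadratic cost concrete. The sketch compares global attention (N² logits for N pixels) with a local k × k window as in option (B); the helper names, the window size k = 7 and the chosen resolutions are arbitrary, for illustration only.

```python
def global_attention_logits(h: int, w: int) -> int:
    """Pairwise logits q_i^T k_j for global attention over an h x w map: N^2."""
    n = h * w
    return n * n

def local_attention_logits(h: int, w: int, k: int) -> int:
    """Logits when each pixel attends only to a k x k window (option B): N * k^2."""
    return h * w * k * k

for side in (224, 56, 14, 7):
    print(side, global_attention_logits(side, side), local_attention_logits(side, side, 7))
```

At 224 × 224, global attention needs about 2.5 × 10⁹ logits, whereas at 14 × 14 it needs 38,416, which is why option (A) applies global attention only after down-sampling.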

The authors' preliminary experiments led them to exclude options (B) and (C) and focus on (A). There are many ways to reduce the spatial size, leading to different architectures. The model presented here uses a simple 2-layer convolutional stem as its first stage, S0. This is followed by stage S1, which employs MBConv blocks with squeeze-excitation (SE), since the spatial size is still too large for global attention. From S2 through S4, each stage can use either MBConv or Transformer blocks, with the constraint that convolution stages must appear before Transformer stages. This leads to 4 different layouts: CCCC, CCCT, CCTT and CTTT, where C denotes Convolution and T denotes Transformer. The experiments favor CCTT as the best configuration.
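The four layouts follow from two constraints: S1 is always convolutional, and no convolution stage may follow a Transformer stage. The sketch enumerates them and, assuming each stage S0 to S4 halves the spatial size (an overall stride of 32, which is our reading of the paper's architecture table), lists the resolution of each stage's output for a 224 × 224 input. The function names are hypothetical.

```python
from itertools import product

def coatnet_layouts():
    """S1 is always C; S2-S4 each use C or T, with every C stage before every T stage."""
    layouts = []
    for rest in product("CT", repeat=3):
        s = "C" + "".join(rest)
        if "TC" not in s:  # reject any convolution stage that follows a Transformer stage
            layouts.append(s)
    return layouts

def stage_resolutions(image_size=224, n_stages=5):
    """Output side length of stages S0..S4, assuming each one halves the spatial size."""
    return [image_size // 2 ** (s + 1) for s in range(n_stages)]

print(coatnet_layouts())       # ['CCCC', 'CCCT', 'CCTT', 'CTTT']
print(stage_resolutions(224))  # [112, 56, 28, 14, 7]
```

Under this assumption, S1 works at 56 × 56, still too large for global attention, while the Transformer stages S3 and S4 of the CCTT layout attend over 14 × 14 and 7 × 7 maps.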

For both the MBConv (yellow) and the Transformer (white) blocks, transformations are of the kind

x ← x + Module(Norm(x))

where Module is MBConv, Self-Attention or FFN (FeedForward Network), and Norm is BatchNorm for MBConv and LayerNorm for Self-Attention and FFN. The Gaussian Error Linear Unit (GELU) is used as the activation function in both MBConv and Transformer blocks.
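A minimal NumPy sketch of the residual form x ← x + Module(Norm(x)) with an FFN as the module. It uses LayerNorm without a learned scale and shift and the common tanh approximation of GELU; both simplifications, and all helper names, are our assumptions. For an MBConv block, Norm would be BatchNorm instead.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each token over its channels (no learned scale/shift, for brevity)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    """tanh approximation of GELU(x) = x * Phi(x)."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def ffn(x, W1, W2):
    """Two-layer feedforward network with a GELU in between."""
    return gelu(x @ W1) @ W2

def prenorm_residual(x, module, norm):
    """x <- x + Module(Norm(x))."""
    return x + module(norm(x))

rng = np.random.default_rng(0)
x = rng.normal(size=(16, 32))         # 16 tokens with 32 channels
W1 = rng.normal(size=(32, 128)) * 0.1
W2 = rng.normal(size=(128, 32)) * 0.1
y = prenorm_residual(x, lambda z: ffn(z, W1, W2), layer_norm)
print(y.shape)  # (16, 32)
```

Keeping the identity path free of normalization means that a module whose output is zero leaves the input untouched.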

In the first block of each stage from S1 to S4, down-sampling is performed independently on the residual branch and on the identity branch.

In the Transformer block, standard max pooling with stride 2 is applied directly to the inputs of both branches of the self-attention module. A channel projection (for example, a 1 × 1 convolution) is applied to the identity branch to enlarge the hidden size. Hence, the down-sampling module can be written as

x ← Proj(Pool(x)) + Attention(Pool(Norm(x))).
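A NumPy sketch of this down-sampling block. We assume a 2 × 2 pooling window with stride 2, a single attention head, LayerNorm without learned parameters, and a 1 × 1 convolution for Proj, implemented as a per-pixel matrix product; the weight names are hypothetical.

```python
import numpy as np

def max_pool2x2(x):
    """Max pooling with a 2x2 window and stride 2 on an (H, W, C) map (H, W even)."""
    H, W, C = x.shape
    return x.reshape(H // 2, 2, W // 2, 2, C).max(axis=(1, 3))

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def attention(x, Wq, Wk, Wv):
    """Single-head global attention on an (H, W, C) map."""
    H, W, C = x.shape
    X = x.reshape(H * W, C)
    logits = (X @ Wq) @ (X @ Wk).T
    a = np.exp(logits - logits.max(axis=1, keepdims=True))
    a /= a.sum(axis=1, keepdims=True)
    return (a @ (X @ Wv)).reshape(H, W, -1)

def transformer_downsample(x, W_proj, Wq, Wk, Wv):
    """x <- Proj(Pool(x)) + Attention(Pool(Norm(x)))."""
    identity = max_pool2x2(x) @ W_proj  # 1x1 conv = per-pixel channel projection
    residual = attention(max_pool2x2(layer_norm(x)), Wq, Wk, Wv)
    return identity + residual

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8, 16))           # 8 x 8 map with 16 channels
W_proj = rng.normal(size=(16, 32)) * 0.1  # enlarge the hidden size from 16 to 32
Wq, Wk = rng.normal(size=(16, 16)) * 0.1, rng.normal(size=(16, 16)) * 0.1
Wv = rng.normal(size=(16, 32)) * 0.1
y = transformer_downsample(x, W_proj, Wq, Wk, Wv)
print(y.shape)  # (4, 4, 32)
```

Both branches end up at half the resolution with the enlarged channel count, so they can be summed.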

The standard MBConv block down-samples by using stride 2 in the depthwise convolution. In CoAtNet's MBConv block, the residual branch is instead down-sampled by a stride-2 convolution applied to the normalized inputs. The module can be expressed as follows:

x ← Proj(Pool(x)) + Conv(DepthConv(Conv(Norm(x), stride=2))).

In depthwise convolution, the convolution is applied to a single channel at a time: each channel of the input is convolved with its own dedicated kernel, so the filters have size k × k × 1 (see figure below).
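Using one k × k × 1 filter per channel makes depthwise convolution much cheaper than a standard convolution, whose filters span all input channels. The count below ignores biases and uses 64 channels purely as an example; the function names are hypothetical.

```python
def conv_params(k: int, c_in: int, c_out: int) -> int:
    """Weights of a standard k x k convolution: each of the c_out filters is k x k x c_in."""
    return k * k * c_in * c_out

def depthwise_params(k: int, channels: int) -> int:
    """Weights of a depthwise k x k convolution: one k x k x 1 filter per channel."""
    return k * k * channels

print(conv_params(3, 64, 64))   # 36864
print(depthwise_params(3, 64))  # 576
```

The ratio is exactly c_out (64 here), which is why MBConv pairs a cheap depthwise convolution with 1 × 1 convolutions that mix information across channels.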

Results

The original paper on CoAtNets reports strong results on several benchmarks. Notably, as of March 2022, the state of the art in image classification on ImageNet was a CoAtNet model (CoAtNet-7: 90.88% top-1 accuracy, 2440M parameters; see here for more).

Useful links

CoAtNet: Marrying Convolution and Attention for All Data Sizes
Z. Dai, H. Liu, Q. V. Le, M. Tan
arXiv:2106.04803v2 [cs.CV] (2021).

Stand-Alone Self-Attention in Vision Models
P. Ramachandran, N. Parmar, A. Vaswani, I. Bello, A. Levskaya, J. Shlens
arXiv:1906.05909v1 [cs.CV] (2019).

2D Convolution (link).

Multi-Head Attention (link).

Code implementations (PyTorch and TensorFlow) (link).
