Neural Networks as Decision Trees

More or less obvious transpositions

Photo by Larisa Birta

This post is inspired by a recent article (which we will not cover) claiming that neural networks are decision trees. It is certainly not the only article to address the topic. Paying too much attention to articles of this kind — neural networks are decision trees, compositions of splines, kernel machines — one may end up believing that neural networks are equivalent to any ML construct one chooses to name…

ReLU activations naturally determine tree structures

The following simple argument is from the paper Towards Interpretable ANNs: An Exact Transformation to Multi-Class Multivariate Decision Trees by Nguyen, Kasmarik and Abbass. Consider a feed-forward neural network whose hidden layers use the ReLU activation. Fix a hidden layer, say the k-th; we regard it as fixed so that the index k can be omitted. The index j refers to nodes in this layer, the index i to nodes in the preceding layer (the (k−1)-th). Denote with zj the value of hidden node j in layer k before the activation:

\displaystyle z_j = \sum_{i=1}^{I} w_{ij} \, H_{i} + b_j \,.

The H values are the activations coming from the preceding layer (the inputs to the k-th layer) and bj is a bias term. The post-activation value hj either coincides with zj or does not (in the latter case the ReLU returns 0). The possibilities are depicted in the following figure.

Due to the nature of the ReLU activation, the output of a node after activation is either 0 or exactly the pre-activation value of that node (that is, hj = zj). It is then easy to see that each hidden layer of the neural network can be transformed into a binary decision tree: the decision at each tree stage is made by the activation of the corresponding node in the hidden layer, based on whether or not the value before the activation function is greater than 0.
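To make the correspondence concrete, here is a minimal sketch (with a made-up toy layer, not taken from the paper) that extracts, for a given input, the binary pattern of ReLU activations of a hidden layer; each entry of the pattern answers one "is zj > 0?" question, i.e. one decision node of the equivalent tree.

import numpy as np

rng = np.random.default_rng(0)

# A toy hidden layer: weights W (3 inputs -> 4 hidden units) and bias b.
W = rng.normal(size=(4, 3))
b = rng.normal(size=4)

x = rng.normal(size=3)          # an input vector
z = W @ x + b                   # pre-activations z_j
h = np.maximum(z, 0.0)          # ReLU: h_j is either 0 or z_j

# The binary decision pattern: one yes/no split per hidden unit.
pattern = (z > 0).astype(int)
print("pre-activations:", z)
print("activations    :", h)
print("decision path  :", pattern)   # e.g. [1 0 1 0] identifies a branch of the tree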

As for explainability, it is clear that the size of the tree grows exponentially as the network size grows; it’s like going from one black box to another.

C-Net

There is a method for generating multivariate decision trees (MDTs) from neural networks. We present the original C-Net architecture (there is a newer version which we will not cover). The procedure is the following. After the neural network is trained, new data is fed in and the outputs of the last hidden layer are computed. In other words, from a set of training and test data, denoted with <Xt, Yt> and <XT, YT> respectively, we can compute the mapping between the last hidden layer and the output, denoted as <Ht, Yt> and <HT, YT>. We retain these two sets, representing the relationship between the last hidden layer and the output layer, for the next stage, in which they are used to train a Quinlan C5 univariate decision tree (UDT) whose algorithm uses an entropy-based information gain ratio as its branch-splitting criterion. A decision tree can then be represented by a set of polyhedra expressed in the form of linear constraints. These constraints have the form Hj(Xt) op Cj, where op is one of the binary operators {≤, <, =, >, ≥} and Cj is the numeric threshold of the constraint on input Hj. To obtain a multivariate form of the expression, a back-projection from the output of the neural network to its input is needed.

The full algorithm is given in the original paper (see the links below).
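As an illustration of the first stage only, here is a hedged sketch: an untrained toy network stands in for a trained one, and scikit-learn's DecisionTreeClassifier (a CART-style tree) replaces the Quinlan C5 learner; both substitutions are assumptions made for brevity, not the paper's exact setup.

import torch
from torch.nn import Sequential, Linear, ReLU
from sklearn.tree import DecisionTreeClassifier

# Hypothetical network: everything up to the last hidden layer...
backbone = Sequential(Linear(4, 16), ReLU(), Linear(16, 8), ReLU())
head = Linear(8, 3)                 # ...plus the output layer (3 classes)

# Toy data standing in for <Xt, Yt>.
X_t = torch.rand(200, 4)
y_t = torch.randint(0, 3, (200,))

with torch.no_grad():
    H_t = backbone(X_t)             # last-hidden-layer outputs, i.e. <Ht, Yt>

# Next stage: fit a univariate decision tree on the hidden representation.
udt = DecisionTreeClassifier(max_depth=3)
udt.fit(H_t.numpy(), y_t.numpy())
print("tree depth:", udt.get_depth())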

Useful links

Neural Networks are Decision Trees
C. Aytekin
arXiv:2210.05189 [cs.LG], 2022.

Towards Interpretable ANNs: An Exact Transformation to Multi-Class Multivariate Decision Trees
D. T. Nguyen, K. E. Kasmarik, H. A. Abbass
arXiv:2003.04675 [cs.LG], 2020.

C-Net: A Method for Generating Non-deterministic and Dynamic Multivariate Decision Trees
H. A. Abbass, M. Towsey, G. D. Finn
Knowledge and Information Systems, Volume 3, pp. 184–197, 2001 (link).

Rectifier (ReLU activation) – Wikipedia entry.

The illusion of learning (link).

Explainable AI – Wikipedia entry.

Forward-Forward algorithm

Photo by Tolga Ulkan

2022 went out with Hinton’s latest effort — The Forward-Forward Algorithm: Some Preliminary Investigations. It is not my intention to stir up controversy about Hinton, but to this day it still escapes me what his real contribution to neural networks is. The last time I covered an article by Hinton was for Capsule Networks (whatever happened to them?) a few years ago.

There are some issues with backpropagation: first, even though neural networks are loosely modeled on real neuronal functioning, backpropagation has no biological counterpart; second, everything one puts into a neural network (as a black box) has to be modeled as a differentiable module to work with backpropagation.

Main idea

Hinton’s last paper introduces the Forward-Forward (FF) learning method with the following key features:

(a) FF replaces the forward and backward passes of backpropagation with two forward passes: one operates on positive (real) data and the other on negative data;

(b) each layer has its own objective function, that is, a measure of goodness for positive and negative data;

(c) FF computes the gradients locally using a local objective function, so there is no need to backpropagate the errors.

Looking at a piece of the implementation code for the layer train method, the input is literally split into a positive batch and a negative batch to operate on.

Learning with a simple layer-wise goodness function

The sum of the squared activities in a layer can be used as the “goodness” but there are many other possibilities, including minus the sum of the squared activities. Specifically, we look to correctly classify input vectors as positive data or negative data when the probability that an input vector is positive is given by the following (θ is a threshold term and σ denotes the logistic function):

\displaystyle p(\mathsf{positive}) =  \sigma\left( \sum_j y_j^2 - \theta \right)\,.

A single hidden layer can be learned using the following criterion: the sum of squared activities of the hidden units has to be high for positive data (above the threshold value θ) and low for negative data.

A necessary observation: since it would be trivial to distinguish positive from negative data simply by using the length of the activity vector in the first hidden layer as an input to the second hidden layer (no new features would need to be learned), FF normalizes the length of the hidden vector before using it as input to the next layer. Briefly, the activity vector in the first hidden layer has a length and an orientation: the length is used to define the goodness for that layer, and only the orientation is passed to the next layer.
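A minimal sketch of a single FF layer, loosely modeled on the public PyTorch implementations mentioned below: the layer normalizes its input, computes goodness as the sum of squared activities, and optimizes a local logistic loss that pushes goodness above θ for positive data and below θ for negative data. Layer sizes, θ and the learning rate are illustrative assumptions, not values from the paper.

import torch
from torch import nn

class FFLayer(nn.Module):
    def __init__(self, d_in, d_out, theta=2.0, lr=0.03):
        super().__init__()
        self.linear = nn.Linear(d_in, d_out)
        self.relu = nn.ReLU()
        self.theta = theta
        self.opt = torch.optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        # Pass only the orientation of the input to this layer.
        x = x / (x.norm(dim=1, keepdim=True) + 1e-4)
        return self.relu(self.linear(x))

    def train_step(self, x_pos, x_neg):
        g_pos = self.forward(x_pos).pow(2).sum(dim=1)   # goodness on positive data
        g_neg = self.forward(x_neg).pow(2).sum(dim=1)   # goodness on negative data
        # Local objective: push g_pos above theta and g_neg below theta.
        loss = torch.log1p(torch.exp(torch.cat([self.theta - g_pos,
                                                g_neg - self.theta]))).mean()
        self.opt.zero_grad()
        loss.backward()          # gradients stay local to this layer
        self.opt.step()
        return loss.item()

layer = FFLayer(784, 500)
x_pos, x_neg = torch.rand(32, 784), torch.rand(32, 784)
print(layer.train_step(x_pos, x_neg))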

A supervised example

To implement supervised learning with FF, one way is to include the class labels in the input (see figure below).

An image with the correct label constitutes positive data and an image with an incorrect label constitutes negative data. The only difference between positive and negative data is the label, so FF should ignore all image features that do not correlate with the label.

After training on the MNIST dataset with FF, a test digit can be classified by running the net with a particular label as part of the input and accumulating the goodnesses of all but the first hidden layer. This has to be done for each label separately; the label with the highest accumulated goodness is then chosen. The paper reports that, during training, a forward pass from a neutral label was used in order to pick hard negative labels.

With MNIST, after training all the layers, to make a prediction for a test image x we evaluate the pair (x, y) for every label y in {0, 1, …, 9} and choose the label that maximizes the accumulated goodness.
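A hedged sketch of this test-time procedure, assuming FF layers like the one sketched earlier and a hypothetical overlay helper that writes a one-hot label into the first pixels of the image (as the public implementations do); names and shapes are illustrative.

import torch

def overlay(x, y, n_classes=10):
    # Hypothetical helper: embed the label y as a one-hot code in the first pixels.
    x = x.clone()
    x[:, :n_classes] = 0.0
    x[range(x.shape[0]), y] = x.max()
    return x

def predict(layers, x, n_classes=10):
    goodness_per_label = []
    for label in range(n_classes):
        h = overlay(x, torch.full((x.shape[0],), label))
        goodness = []
        for i, layer in enumerate(layers):
            h = layer(h)
            if i > 0:                                 # skip the first hidden layer
                goodness.append(h.pow(2).sum(dim=1))
        goodness_per_label.append(sum(goodness))
    return torch.stack(goodness_per_label, dim=1).argmax(dim=1)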

Performance

Hinton’s paper reports a brief comparison between FF and backpropagation on CIFAR-10. The test performance of FF is slightly worse than backpropagation. There is also an interesting page about the analysis of performance versus backpropagation.

Code implementations

I’d like to mention two GitHub repositories, one from Nebuly-ai and the other from Mohammad Pezeshki. Both are PyTorch implementations.

Useful links

The Forward-Forward Algorithm: Some Preliminary Investigations
G. Hinton
arXiv:2212.13345 [cs.LG], 2022.

Code from Nebuly-ai.

Code from M. Pezeshki.

Detailed Backpropagation Algorithm (link).

Interesting performance analysis page.

Dropout tales

Insights into a popular regularization technique

Photo by Olga Tutunaru

Dropout is an effective regularization technique used to reduce overfitting in neural networks. It works like this: given a feedforward neural network, at training time remove some neurons from each non-output layer, each with a certain probability. For example, if the probability is 0.5 (it can vary across layers), you flip a coin to decide whether a given neuron stays in or goes out. The picture below shows the original network (called the base or parent network) on the left and the network after applying dropout on the right.

As you may have noticed, at training time the network on the right (picture above) is a simpler network (fewer units), prone to express a simpler model.

At test time no units are dropped, so that the full network is used to make predictions. The picture below shows what happens at training time and during test.

At training time, a unit (neuron) is present with a certain probability p and is connected to units in the next layer with weights w. At test time, the unit is always present and weights are multiplied by p. This is because we would like the outputs of units during test time to be equivalent to their expected outputs at training time.

In fact, dropout retains a unit with probability p and removes a unit (the output of a unit is set to 0) with probability 1 − p. This means that if the output of a unit prior to dropout was x, then after dropout the expected output would be E[output] = px + (1 − p) · 0 = px. Therefore, to ensure that the outputs have the same expectation at test time as they did during training, we have to multiply weights by p at test time.

However, this implementation of dropout is undesirable because it requires scaling of neuron outputs at test time. This is bad for test-time performance, so it is preferable to use inverted dropout, where the scaling occurs at training time instead of testing time. In inverted dropout, the output of any retained unit is divided by p before the value is propagated to the next layer. In this case

\displaystyle \text{E}[\text{output}] = p \cdot \frac{x}{p} + (1 - p) \cdot 0 = x\,,

avoiding output scaling at test time.
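A minimal NumPy sketch of inverted dropout applied to a layer's activations (here p is the retain probability, as in the discussion above):

import numpy as np

def inverted_dropout(x, p, training=True, rng=np.random.default_rng()):
    """Inverted dropout: scale the retained activations by 1/p at training time."""
    if not training:
        return x                       # nothing to do at test time
    mask = rng.random(x.shape) < p     # keep each unit with probability p
    return x * mask / p

x = np.ones((2, 5))
print(inverted_dropout(x, p=0.5))                   # surviving entries become 2.0
print(inverted_dropout(x, p=0.5, training=False))   # unchanged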

Dropout as a regularizer

Because units can go away at random, each neuron may miss an important input (or several) from the previous layer, so it cannot rely on any single input. The neuron has to spread out the weights across its incoming connections, causing the weights to shrink. This shrinking lowers the squared norm of the weights. Hence dropout is, in some respects, similar to L2 regularization. This explanation can be found in this video lecture by Andrew Ng.

The fact that some form of L2 regularization was hiding behind the dropout technique was already discussed in a 2013 article. One of the study’s findings is that dropout can be seen as an attempt to apply an L2 penalty after normalizing the feature vector by a quantity depending on the diagonal of an estimate of the Fisher information matrix.

The picture above compares two L2 regularizers (take a look at this page if you need a quick recap on regularization). The solid ellipses are level surfaces of the likelihood and the dashed curves are level surfaces of the regularizer. The top panel shows a classic spherical L2 regularizer. Let I be the Fisher information matrix. If I were a multiple of the identity matrix, these level surfaces would be perfectly spherical. In dropout, the level surfaces are non-spherical (bottom panel) due to the normalization of the features by diag(I)⁻¹ᐟ²: the L2 penalty is applied after scaling (the features have been balanced out).

Dropout as a bagging algorithm

There is an obvious link between this intuitive view of regularization and the size/complexity of the network (see picture below). Smaller networks correspond to rigid and simple models. To avoid overfitting, it is sometimes useful to exploit a method that helps reduce complexity, returning a better performing model. Intuitively, fewer neurons (units) in action correspond to simpler models.

As noted earlier, at training time the network (after applying dropout) is a simpler network (fewer units), prone to express a simpler model, possibly reducing overfitting. The network is trained to produce accurate predictions on unseen data even in unfriendly conditions where some neurons are missing.

Recall that to learn with bagging, we define t different learners (ensemble models), construct t different datasets by sampling from the training set with replacement, and then train model i on dataset i. The bagging meta-algorithm is depicted below: (1) create multiple data sets Dᵢ through sampling with replacement; (2) employ multiple learners Lᵢ in parallel; (3) combine all learners using an averaging or majority-vote strategy.

Dropout aims to approximate this process, but with an exponentially large number of neural networks. Dropout trains the ensemble consisting of (possibly all) subnetworks that can be formed by removing non-output units from a given base network (see figure below). A base network with N non-output units can generate up to 2ᴺ thinned subnetworks.

When training with dropout, we use minibatches and each time we load an example into a minibatch, we randomly sample a different binary mask (0 out, 1 in) applying to all of the input and hidden units in the network.

There is a significant difference between bagging and dropout. Bagging models are all independent. Dropout models, instead, share parameters: each model inherits a different subset of parameters from the parent neural network. This parameter sharing makes it possible to represent an exponential number of models with a tractable amount of memory. Moreover, dropout training differs from bagging in that each model is trained for only one step.

In bagging, the prediction of the ensemble is the arithmetic mean of all the individual predictions. In the case of dropout, at test time it is not feasible to explicitly average the predictions from exponentially many thinned models. However, there is a simple approximate averaging method that works well in practice. There is no theoretical guarantee (at the moment) for the accuracy of this approximation, but empirically it performs very well. The idea is to use a single neural net at test time without dropout, obtained by adjusting the weights as shown before, i.e. the outgoing weights of a retained unit are multiplied by p at test time. We already observed that this ensures that, for any hidden unit, the actual output at test time equals the expected output at training time. Thanks to this scaling, a large number of networks with shared weights can be combined into a single neural network to be used at test time.

Dropout in practice

Dropout is implemented in PyTorch through the nn.Dropout class. nn.Dropout randomly zeroes some of the elements of the input tensor with probability p using samples from a Bernoulli distribution. Note that here p is the probability to drop the unit; this is different from our previous usage (so far we have denoted with p the probability to retain a unit).

Below, a minimal example showing how Dropout sets several entries of the matrix x to zero (with p = 0.75, about 3 units out of 4 are dropped).

import torch
from torch.nn import Dropout

x = torch.full((3, 5), 1.0)    # a 3x5 matrix of ones
print(x)
dropout = Dropout(p=0.75)      # here p is the probability of zeroing an entry
y = dropout(x)                 # surviving entries are rescaled by 1/(1-p) = 4
print(y)

The TensorFlow analogue is tf.keras.layers.Dropout. Below, a small neural network example with nn.Dropout modules interspersed between Linear layers.

import torch
from torch.nn import Sequential, Linear, ReLU, Dropout

model = Sequential(Linear(10, 100), ReLU(),
                   Dropout(),              # default p = 0.5
                   Linear(100, 50), ReLU(),
                   Dropout(),
                   Linear(50, 2))
t = torch.rand(10)                          # a single (unbatched) input
print(model(t))

If the neural network is defined as a class, it is possible to specify nn.Dropout occurrences in the forward method.
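For example, here is a minimal sketch of such a class (an arbitrary toy architecture, chosen only for illustration):

import torch
from torch import nn

class SmallNet(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.fc1 = nn.Linear(10, 100)
        self.fc2 = nn.Linear(100, 2)
        self.dropout = nn.Dropout(p=p)     # p is the drop probability

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)                # active in train mode, disabled in eval mode
        return self.fc2(x)

net = SmallNet()
net.eval()                                 # switch dropout off for inference
print(net(torch.rand(4, 10)))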

What are the best values for p? There is no single value that works in all situations; the key is to repeat experiments until a satisfactory value is reached. As initial values to be refined later, some sources suggest including an input unit with probability 0.8 (p = 0.2) and a hidden unit with probability 0.5 (p = 0.5).

Useful links

Dropout: A Simple Way to Prevent Neural Networks from Overfitting
N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, R. Salakhutdinov
Journal of Machine Learning Research, 15, pp. 1929–1958, 2014 [link].

Fundamentals of Deep Learning
N. Buduma, N. Locascio
pp. 36–37, O’Reilly Media, Inc., 2017.

Dropout Training as Adaptive Regularization
S. Wager, S. Wang, P. Liang
arXiv:1307.1493 [stat.ML], 2013.

Deep Learning
I. Goodfellow, Y. Bengio, A. Courville
Chapter 7, pp. 224–270, MIT Press, 2016 [link].

Dropout — PyTorch docs page.

How does dropout work during testing in neural networks? [Prylipko]

Neural Interpreters

Sparse attention mechanisms and several analogies with programming codes


Photo by Shubham Dhage

For this post we refer to the paper “Dynamic Inference with Neural Interpreters” by Rahaman et al. (2021).

Overview

A neural interpreter is a collection of modules resembling a program: it is a set of scripts, which are made up of functions, which in turn are made up of lines of code. Essentially, this is an attention-based network in which inputs are routed through a sequence of functions in a way that is learned end-to-end.

Convolutional networks reuse computational units, like filters, laterally (once the depth is fixed), while recurrent neural networks only reuse computational units (RNN cells) vertically, i.e., in depth. Such rigidity in the way networks reuse their units is believed to be one of the reasons for poor generalization. The neural interpreter aims to be an architecture made of independent and composable pieces, capable of relaxing this rigidity in computation reuse.

Input and Output

Assume that the input set contains vector embeddings of image patches or entire images. These elements are vectors of a certain dimension din. The input set additionally includes one or more learned vectors, called CLS tokens, for which the corresponding outputs interface with their respective classifiers. The output is another set of vectors whose dimension is dout (with the same cardinality as the input set).

Fig. 1. Neural Interpreter

Scripts

A neural interpreter is a stack of nₛ scripts mapping one set of vectors X = {x₁, x₂, …} to another Y = {y₁, y₂, …} with the same number of elements:

\mathbf{Y} = \mathsf{Neural \; Interpreter}(\mathbf{X}) = \left[ \mathsf{Script}_{n_s} \, \circ \, \cdots \, \circ\, \mathsf{Script}_1 \right](\mathbf{X})

Fig. 2. A neural interpreter is a stack of scripts

Increasing the number of scripts nₛ will increase the depth of the architecture. A script has four components:

  1. a type inference module;
  2. a type matching mechanism;
  3. a set of functions;
  4. an interpreter.

We will soon describe these four components.

Functions

Each script contains functions. Functions are vector-valued instructions to other components in the script. Formally, a function fᵤ is a pair (s, c) where s is called the signature and c is called the code (u is used as an index). The signature is a normalized vector of dimension dtype and indicates to the type matching mechanism (see below) which inputs are to be routed to fᵤ (note the analogy with programming). The vector c, a learned parameter for each function, is the code that tells the function what to do (further details in a moment). Each fᵤ has its own code, which stays the same across inputs.

Fig. 3. Functions inside a script

For example, f₁, f₂ and f₃ all share their global parameters but each has its own code. Samples can jump flexibly from one function to another. The way each sample is routed through the network is determined on a per-sample basis: every example has its own independent path through the network, and the routing itself is entirely learned.

Not every example is routed to every function, so let’s see how an example ends up in a function’s scope.

Type Matching and Inference

Before getting to the functions, a sort of higher-level attention is performed on the set elements. Type matching is responsible for routing the information elements through functions. This is a three step procedure.

a) At the beginning, an input set element x is processed through an MLP module (called type inference module) to obtain a type vector t whose dimension is dtype. This vector lies in the same unit hypersphere 𝓣 containing the signature vectors s.

b) Consider a function fᵤ. Define a distance based on the cosine similarity between the type vector tᵢ and the signature sᵤ, that is d𝓣(sᵤ, tᵢ) = 1 − sᵤ · tᵢ.

c) Then a sort of softmax with normalization is performed, returning a coefficient Cᵤᵢ. However, Cᵤᵢ is set to 0 if the distance between sᵤ and tᵢ is larger than τ, a value called the truncation parameter. This introduces sparsity in the model. Fixing u and i, Cᵤᵢ is the compatibility between function fᵤ and set element xᵢ: xᵢ can be processed by fᵤ only if Cᵤᵢ is sufficiently large, and if Cᵤᵢ = 0 then fᵤ cannot access xᵢ (a small sketch of this matching step is given after Fig. 4).

Fig. 4. Type matching and inference
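A hedged sketch of the three matching steps, with made-up dimensions, a simple MLP as the type inference module, and a softmax-with-truncation whose exact normalization follows the spirit rather than the letter of the paper (τ and the temperature σ are illustrative):

import torch
from torch import nn
import torch.nn.functional as F

d_in, d_type, n_funcs, n_elems = 64, 16, 3, 5
tau, sigma = 0.6, 0.1                        # truncation and temperature (made up)

type_inference = nn.Sequential(nn.Linear(d_in, d_type), nn.GELU(),
                               nn.Linear(d_type, d_type))
signatures = F.normalize(torch.randn(n_funcs, d_type), dim=-1)    # s_u

x = torch.randn(n_elems, d_in)               # set elements x_i
t = F.normalize(type_inference(x), dim=-1)   # type vectors t_i on the hypersphere

dist = 1.0 - signatures @ t.T                # d(s_u, t_i), shape (n_funcs, n_elems)
C = torch.softmax(-dist / sigma, dim=0)      # compatibilities across functions
C = torch.where(dist <= tau, C, torch.zeros_like(C))   # truncation -> sparsity
print(C)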

Modulated Linear layers and modulated MLPs

The following constructs are needed to define an attention mechanism later on. They should be interpreted as programmable modules (the program is determined by the code c). Modulated linear layers act like linear layers, with the only difference that the linear transformation is applied not to x but to

x´ = x ⊗ LayerNorm(W𝒸 c)

where W𝒸 is a learnable matrix that constitutes a set of parameters shared among all functions in the same script (the symbol ⊗ denotes entry-wise product). In short

\mathbf{y} = \mathsf{ModLin}(\mathbf{x}; \mathbf{c} ) = \mathbf{W}\mathbf{x}^\prime + \mathbf{b}

where W is a weight matrix and b is a bias term. Having defined the modulated linear layer, one may stack L of them (sharing the same code c), interspersed with GELU activation functions, to get the modulated MLP:

\begin{aligned} \mathbf{y} &= \mathsf{ModMLP}(\mathbf{x}; \mathbf{c} )\\ &= ( \mathsf{ModLin}_L(\bullet; \mathbf{c} ) \, \circ \, \mathsf{Activation} \, \circ \, \cdots \, \circ \, \mathsf{ModLin}_1(\bullet; \mathbf{c} ) )(\mathbf{x})\,. \end{aligned}
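A hedged PyTorch sketch of ModLin and a two-layer ModMLP as defined above; dimensions are arbitrary, and the sharing of W𝒸 across the functions of a script is only hinted at here by passing different code vectors to the same module.

import torch
from torch import nn

class ModLin(nn.Module):
    """Linear layer modulated by a code vector c (a sketch, not the authors' code)."""
    def __init__(self, d_in, d_out, d_code):
        super().__init__()
        self.Wc = nn.Linear(d_code, d_in, bias=False)   # maps the code to a modulation
        self.norm = nn.LayerNorm(d_in)
        self.lin = nn.Linear(d_in, d_out)

    def forward(self, x, c):
        x_mod = x * self.norm(self.Wc(c))   # x' = x ⊗ LayerNorm(Wc c)
        return self.lin(x_mod)              # W x' + b

class ModMLP(nn.Module):
    def __init__(self, d, d_code):
        super().__init__()
        self.l1, self.l2 = ModLin(d, d, d_code), ModLin(d, d, d_code)
        self.act = nn.GELU()

    def forward(self, x, c):
        return self.l2(self.act(self.l1(x, c)), c)

x, c = torch.randn(5, 32), torch.randn(32)   # five set elements and one function code
print(ModMLP(32, 32)(x, c).shape)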

ModAttn

A conditional multi-head attention mechanism is used (conditioned, that is, on the code vector cᵤ of function fᵤ). Queries, keys and values are computed with ModLin layers (instead of plain linear layers) for each head h:

\begin{aligned} \mathbf{k}_{uhi} &= \mathsf{ModLin}_{\textsf{key}}^h(\mathbf{x}_i \,; \, \mathbf{c}_u )\\ \mathbf{q}_{uhi} &= \mathsf{ModLin}_{\textsf{query}}^h(\mathbf{x}_i \,; \, \mathbf{c}_u )\\ \mathbf{v}_{uhi} &= \mathsf{ModLin}_{\textsf{value}}^h(\mathbf{x}_i \,; \, \mathbf{c}_u ) \,. \end{aligned}

Then consider again the compatibility coefficients {Cᵤᵢ}; these quantities serve as modulators when evaluating the self-attention weights. The self-attention weights are given by the normalizing expression

\displaystyle W_{uhij} = \frac{\tilde{W}_{uhij}}{\epsilon + \sum_j \tilde{W}_{uhij}}

where ε avoids division by near-zero terms and

\displaystyle \tilde{W}_{uhij} = C_{ui}C_{uj}\left[ \mathsf{softmax}_j \left(\frac{ \mathbf{q}_{uhi} \cdot \mathbf{k}_{uhj} }{\sqrt{d_\textsf{key}}}\right) \right] \,.

For example, fix fᵤ and the head h. Then we have

\displaystyle \tilde{W}_{ij} = C_{i}C_{j}\left[ \mathsf{softmax}_j \left(\frac{ \mathbf{q}_{i} \cdot \mathbf{k}_{j} }{\sqrt{d_\textsf{key}}}\right) \right]

and, after normalization, Wᵢⱼ is the attention weight between elements xᵢ and xⱼ. Intuitively, information about xᵢ and xⱼ is mixed by fᵤ at head h only if Wᵤₕᵢⱼ is not 0; the weight vanishes in two cases: 1) at least one of the compatibility factors is zero (that is, fᵤ cannot access xᵢ or xⱼ), or 2) the self-attention weight (the softmax part) is close to zero. Finally, the following linear combination is computed

\displaystyle \tilde{\mathbf{y}}_{uhi} = \sum_j W_{uhij}\mathbf{v}_{uhj}

and the final output is

\displaystyle \tilde{\mathbf{y}}_{ui} = \mathsf{ModLin}(\tilde{\mathbf{y}}_{ui;h}\,;\,\mathbf{c}_u)

where the semicolon separating h from ui indicates that the results of various heads are folded (as usual in multi-head attention) into one single object.

Line of Code

A line of code layer is a ModAttn layer followed by a ModMLP layer (see figure below, on the right). Both layers share the same condition vector, and there are weighted residual connections between them.

Fig. 5. Lines of code

A line of code (LOC) is a line of code layer applied in parallel streams, one per function, as shown in Fig. 5 (right). The inputs of a LOC, say {xᵤᵢ}, carry an extra index u, meaning that each is a specific input to the function fᵤ. If a function fᵤ cannot access xᵤᵢ, then fᵤ acts on xᵤᵢ as the identity function. For example, focus on a particular function fᵤ and on its specific inputs {xᵤᵢ} as i varies. Then

\mathbf{a}_{ui} = \mathbf{x}_{ui} +C_{ui} \tilde{\mathbf{a}}_{ui}

where ãᵤᵢ is the output of the attention layer (ModAttn); then

\mathbf{y}_{ui} = \mathbf{a}_{ui} +C_{ui} \tilde{\mathbf{b}}_{ui}

where \mathbf{\tilde{b}}_{ui} is the output of the MLP module (essentially, a ModMLP layer). Note that if fᵤ cannot access xᵤᵢ (that is, Cᵤᵢ = 0), then the output yᵤᵢ is just xᵤᵢ.

Interpreter

The interpreter layer is a stack of LOCs sharing the same function codes. The interpreter broadcasts a given set element to multiple parallel computational streams, one for each function. Let the number of stacked LOCs be nₗ. Let X = {x₁, x₂, …} and Cᵤ = {Cᵤ₁, Cᵤ₂, …}; then

\mathbf{y}_{i} = \mathbf{x}_{i} + C_{1i}\, \mathcal{L}(\mathbf{X}, \mathbf{c}_1, \mathbf{C}_1) + C_{2i}\, \mathcal{L}(\mathbf{X}, \mathbf{c}_2, \mathbf{C}_2) + \cdots

where

\mathcal{L} = \underbrace{\mathsf{LOC}_{n_l}\,\circ\, \cdots\,\circ\,\mathsf{LOC}_1}_{n_l \, \textsf{times}} \,.

Essentially, the output is a weighted sum whose coefficients are the compatibilities of the elements with the respective functions. Given a set of inputs and an instruction (the function code), the role of the interpreter is to execute that instruction and compute the output.

Increasing the number of LOCs nₗ increases the architecture depth and also the number of parameters.

Function Iteration

We have already seen that the overall model is a stack of multiple scripts. A script can be expressed as a recurrent application of Function Iteration (FnIter)

\{ \mathbf{y}_1, \mathbf{y}_2, \dots\} =( \underbrace{\mathsf{FnIter}\,\circ\, \cdots\,\circ\,\mathsf{FnIter}}_{n_i \, \textsf{times}})( \{ \mathbf{x}_1, \mathbf{x}_2, \dots\})

where FnIter is defined as the composition of the type matching mechanism and the interpreter.

The number of function iterations nᵢ can be increased without increasing the number of parameters, so FnIter enables unit sharing in depth.

Experiments

Some experiments have been conducted on tasks such as learning fuzzy boolean expressions, multi-task image classification and abstract reasoning. However, we do not delve further into these matters, as only time will tell whether this recent architecture pays off.

Useful links

Original article on Neural Interpreters.

Nice discussion with authors (video).

Sharpness-Aware Minimization

This post deals with a recent optimization method for training neural networks, described in the paper Sharpness-Aware Minimization for Efficiently Improving Generalization by P. Foret et al. (December 2020). Honestly, the first time I read the paper’s details I thought the procedure described there (or something similar) must already have been explored many years before by plenty of people… I was even surprised to read that it works in some contexts.

Is loss value not enough?

Modern models are trained through optimization methods that rely only on the training loss. These models can easily memorize the training data and are prone to overfitting: they have more parameters than needed, and this large number of parameters provides no guarantee of proper generalization to the test set.

Sharpness-Aware Minimization (SAM) is a procedure that aims to improve model generalization by simultaneously minimizing loss value and loss sharpness (the pictures below provide an intuitive support for the notion of “sharpness” for a loss landscape).

Fig. 1. Sharp vs wide (low curvature) minimum
Fig. 2. Sharp minimum (left) vs wide minimum (right) for a ResNet trained with SGD (source)

SAM seeks parameters lying in neighborhoods having uniformly low loss value (and not just parameters having low loss value). When SAM procedure is used to update weights, the sharpness of the loss landscape is taken into account.

Empirical studies suggest that SAM improves model generalization ability across a range of widely studied computer vision tasks on datasets like CIFAR{10-100} and ImageNet.

A learning setup

A little premise before going into details. Let

S = \left\{(x_1,y_1), \dots, (x_n, y_n) \right\}

be the training set, whose examples are drawn i.i.d. from a distribution D. We seek to learn a model that generalizes well (roughly, a model that performs well on the test set). We consider a family of models whose parameters are w ∈ W (w is a d-dimensional vector) and a loss function l acting on each single datapoint. A loss function is, typically, a function expressing the discrepancy between the model prediction and the actual observation (label). We define the training set loss as

\displaystyle L_S(w) = \frac{1}{n}\sum_{i=1}^n l(w,x_i,y_i),

that is, the mean of per-data-point errors over S, and the population loss

L_D(w) =\mathbb{E}_{(x,y)\sim D}\, l(w,x,y)

as the mean per-data-point loss over the whole distribution D.

What is the goal of model training? Having observed only S, find model parameters w such that the population loss LD(w) is low. In practice, the training loss LS(w) is used as an estimate of the population loss LD(w), and the model parameters w are selected by solving minw LS(w) with some optimization procedure such as Stochastic Gradient Descent (SGD) or Adam.

For modern models, LS(w) is typically a non-convex function of the parameters w. A problem is that this function has multiple local — and even global — minima in which it assumes similar values while exhibiting significantly different generalization performance (that is, the population loss assumes significantly different values).

What makes SAM different is the focus on minima neighborhoods. Rather than seeking out parameter values w that simply have low training loss LS(w), the SAM procedure seeks out parameter values whose neighborhoods have both low loss and low curvature.

Sharpness

For ρ > 0, the paper derives a bound of the following form, holding with high probability over the draw of the training set:

\displaystyle L_D(w) \leq \underset{\| \epsilon \|_2 \leq \rho}{\max} L_S(w+\epsilon) + h \left(\frac{\|w\|_2^2}{\rho^2} \right)

where h is a strictly increasing function. Adding and subtracting LS(w), the right-hand side can be rewritten as

\displaystyle \left[ \underset{\| \epsilon \|_2 \leq \rho}{\max} L_S(w+\epsilon) - L_S(w) \right] + L_S(w) + h \left(\frac{\|w\|_2^2}{\rho^2} \right).

The term enclosed by square brackets is the sharpness. Note that the more the loss grows around w (steep landscape), the larger is the sharpness. Sharpness measures how quickly the training loss can be increased by moving from w to a nearby parameter value w + ϵ.

Minimization

The function h is replaced by a simpler constant λ (which is not strictly increasing, however…), making the last term a standard L2 regularization term. At this point, the parameter values are chosen by solving the following minimization problem

\displaystyle \underset{w}{\min}\; L_S^{\mathsf{SAM}}(w) + \lambda \|w\|_2^2\,,

where

\displaystyle L_S^{\mathsf{SAM}}(w) = \underset{\| \epsilon \|_p \leq \rho}{\max} L_S(w+\epsilon)

with ρ ≥ 0 a hyperparameter and p in [1, ∞] (a slight generalization, though p = 2 is empirically the best choice).

In order to minimize L_S^{\mathsf{SAM}}(w), an efficient approximation of its gradient is needed. The first step is to take the first-order Taylor expansion of LS(w + ϵ) with respect to ϵ around 0 and substitute it into the L_S^{\mathsf{SAM}} expression. Taking the maximizing argument:

\displaystyle \begin{aligned} \epsilon^*(w) &= \arg\underset{\| \epsilon \|_p \leq \rho}{\max} L_S(w+\epsilon) \\ &\approx \arg\underset{\| \epsilon \|_p \leq \rho}{\max} \left(L_S(w) + \epsilon^\top \nabla_w L_S(w) \right) \\ &= \arg\underset{\| \epsilon \|_p \leq \rho}{\max} \epsilon^\top \nabla_w L_S(w). \end{aligned}

The last expression is just the argmax of the dot product of the vectors ϵ and ∇w LS(w), and the argument that maximizes it is given by a classical dual-norm result (check this dual norm result; the optimal value is denoted with y). An easy intro to dual norms can be found here. Let’s temporarily denote ∇w LS(w) with g. The argument that solves the preceding approximation is

\displaystyle \hat{\epsilon}(w) = \rho\, \mathrm{sign}(g) \frac{ |g|^{q-1} }{ \left(\| g\|_q^q\right)^{1/p}}

where 1/p + 1/q = 1. Since \hat{\epsilon}(w) is the maximizing argument, we can write

\displaystyle \begin{aligned} \nabla_w\,L_S^{\mathsf{SAM}}(w) &\approx \nabla_w\,L_S(w+ \hat{\epsilon}(w)) \\ &= \frac{\mathrm{d}\,(w + \hat{\epsilon}(w)) }{ \mathrm{d}\, w} \, \nabla_w\,L_S(w)|_{ w+ \hat{\epsilon}(w) } \\ &= \nabla_w\,L_S(w)|_{ w+ \hat{\epsilon}(w) }\, +\, \frac{\mathrm{d}\, \hat{\epsilon}(w) }{ \mathrm{d}\, w} \, \nabla_w\,L_S(w)|_{ w+ \hat{\epsilon}(w) }. \end{aligned}

Modern frameworks can easily compute the preceding approximation. However, to speed up the computation, the second-order terms can be dropped, obtaining

\displaystyle \nabla_w\,L_S^{\mathsf{SAM}}(w) \approx \nabla_w\,L_S(w)|_{ w+ \hat{\epsilon}(w) }.

Algorithm

Input: training set S, loss function l, batch size b, step size \eta , neighborhood size \rho .
Output: model trained with SAM.

Fig. 3. SAM parameter update
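A hedged PyTorch sketch of one SAM update with p = 2, following the two-step rule above: compute ε̂ at the current weights, evaluate the gradient at w + ε̂, then take a plain SGD step. The toy model, ρ and η are illustrative choices, not the paper's settings.

import torch
from torch import nn

model = nn.Linear(10, 2)                    # toy model
loss_fn = nn.CrossEntropyLoss()
rho, eta = 0.05, 0.1                        # neighborhood size and step size

x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))

# 1) gradient g of the training loss at w
loss = loss_fn(model(x), y)
grads = torch.autograd.grad(loss, list(model.parameters()))

# 2) epsilon_hat = rho * g / ||g||_2 (the p = 2 case); perturb w in place
grad_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
eps = [rho * g / (grad_norm + 1e-12) for g in grads]
with torch.no_grad():
    for p_, e in zip(model.parameters(), eps):
        p_.add_(e)

# 3) gradient of L_S at w + epsilon_hat, used as the SAM gradient
loss_perturbed = loss_fn(model(x), y)
sam_grads = torch.autograd.grad(loss_perturbed, list(model.parameters()))

# 4) undo the perturbation and take an SGD step with the SAM gradient
with torch.no_grad():
    for p_, e, g in zip(model.parameters(), eps, sam_grads):
        p_.sub_(e)
        p_.sub_(eta * g)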

The JAX code from the authors’ paper and additional info can be found here; another implementation in PyTorch is available here.

Useful links

Original article (2020).

Dual norms result.

Introductory video about dual norms.

Taylor expansion (Wikipedia article).

Code repository from the authors’ paper.

PyTorch implementation available here.

Colab notebook with TensorFlow implementation.

The Kullback-Leibler divergence

In this post we will just spend a few words on a well-known measure of how dissimilar a given distribution is from another reference distribution. First we will give a definition for such a measure and then we will provide some intuitive meaning together with some useful coding snippets.

Definition

Let’s begin with the discrete case. Let P and Q be two probability distributions defined on the same probability space \mathcal{X}. A first attempt might be to consider the average of the difference between the distributions; the following definition is indeed quite close, just a little different. The Kullback-Leibler divergence (also called relative entropy) KL(P‖Q) is defined as the average of the difference between the logarithms of the probabilities P(x) and Q(x):

\mathrm{KL}(P\Vert Q) \, \stackrel{\mathsf{def}}{=} \, \mathbb{E} \big[ \log P(x)  - \log Q(x) \big]\,.

The expectation is taken using the probabilities P (often written as x \sim P). The definition of expectation leads to the expression

\displaystyle \mathrm{KL}(P\Vert Q) = \sum_{x \in \mathcal{X}} P(x) \log\left(\frac{P(x)}{Q(x)}\right).

In the case of continuous distributions we write

\displaystyle \mathrm{KL}(P\Vert Q) = \int_{-\infty}^\infty p(x) \log\left(\frac{p(x)}{q(x)}\right) \,\mathrm{d}x

where p(x) and q(x) are P and Q respective densities.

KL divergence is often called a “distance”, but it is not a distance in the mathematical sense (a metric): KL divergence is not symmetric. This means that KL(P‖Q) is generally different from KL(Q‖P).

If Q(x) is 0 for some x, the KL divergence is not defined unless P(x) = 0 as well. What if P(x) is 0 somewhere? In that case the corresponding term is taken to be zero, since a log(a) tends to 0 as a approaches 0.
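A small sketch of a manual discrete KL computation applying the 0·log(0) = 0 convention (assuming Q(x) > 0 wherever P(x) > 0); the numbers are those of the quick example further below.

import numpy as np

def kl_divergence(p, q):
    """KL(P || Q) for discrete distributions, with 0*log(0) treated as 0."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    mask = p > 0                       # terms with P(x) = 0 contribute nothing
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

print(kl_divergence([9/25, 12/25, 4/25], [1/3, 1/3, 1/3]))   # ~0.0853
print(kl_divergence([1/3, 1/3, 1/3], [9/25, 12/25, 4/25]))   # ~0.097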

Motivations behind the definition

A first intuition comes from the fact that if {pᵢ} and {qᵢ} are two probability mass functions, that is, two countable or finite sequences of nonnegative numbers that sum to one, then

\displaystyle  \sum_{i} p_i \log \left(\frac{p_i}{q_i}\right) \geq 0

with equality if and only if pi = qi for all i. The fact that the divergence of one probability distribution with respect to another is nonnegative and zero only when the two distributions are the same suggests the interpretation of KL divergence as a “distance” between two distributions, that is, a measure of how different the two distributions are.

A second intuition about the fact that KL divergence actually expresses some kind of distance between two distributions comes from the expression

\begin{aligned} \displaystyle \mathrm{KL}(P\Vert Q) &= \int_{-\infty}^\infty p(x) \left( \log p(x) - \log q(x) \right) \, \mathrm{d}x \\& = \int_{-\infty}^\infty p(x) D(x)\, \mathrm{d}x \end{aligned}

where it is immediate to recognize that the difference between logarithms D(x) is a term expressing the gap between the two distributions. If the average gap is small, then the two distributions are “similar” or “close”.

Fig. 1. Two continuous distribution densities p and q and their respective logarithmic transformations log(p) and log(q)

Connection with cross entropy

The KL divergence KL(P‖Q) is equal to

\begin{aligned} \displaystyle \mathrm{KL}(P\Vert Q) &= - \sum_x P(x)  \log Q(x) + \sum_x P(x) \log P(x) \\& = H(P,Q) - H(P) \end{aligned}

where H(P, Q) is the cross entropy of P and Q and H(P) is the entropy of P. As we said, KL(P‖Q) can be thought of as a measure of how far the distribution Q is from the distribution P. But cross entropy is itself such a measure… the difference is that cross entropy attains a — generally nonzero — minimum when P = Q, namely H(P, P) = H(P); so in the KL divergence we subtract the entropy term H(P) to make the minimum value 0. This is coherent with the property that the distance of an object from itself should be zero.

Quick example

Let P and Q be the following distributions (each possible outcome x is in \mathcal{X} = {0, 1, 2}):

x                     0        1        2
Distribution P(x)   9/25    12/25     4/25
Distribution Q(x)    1/3      1/3      1/3

Fig. 2. The distributions P and Q

The following picture shows both P (amber) and Q (gray).

Fig. 3. P and Q overlapping

Next picture shows the logarithm of distributions with the difference D at x = 2.

Fig. 4. log(P) and log(Q) with difference D

Let’s calculate KL(P‖Q).

\begin{aligned} \displaystyle \mathrm{KL}(P\Vert Q) &= \sum_x P(x)  \log \left( \frac{P(x)}{Q(x)} \right) \\&= 9/25 \log\left(\frac{9/25}{1/3}\right) + 12/25 \log\left(\frac{12/25}{1/3} \right) + 4/25 \log\left(\frac{4/25}{1/3} \right) \\& \approx 0.0853\,. \end{aligned}

Interchanging the arguments, we find that KL(Q‖P) is approximately 0.0974, which differs from the previous value.

Evaluate KL divergence with Python

Import the entropy function

from scipy.stats import entropy

and then compute KL(P‖Q) from the example above in just one line.

entropy([9/25, 12/25, 4/25], qk=[1/3, 1/3, 1/3])
0.0852996013183706

Below, a simple Python coding example reproducing figures 1–4. Note that the two continuous density curves are multiplied by a magnifying coefficient for scaling purposes.

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm, skewnorm

p = [9/25, 12/25, 4/25]
q = [1./3, 1./3, 1./3]
xx = ['0', '1', '2']

logq = np.log(q)
logp = np.log(p)

# Fig. 3: the two discrete distributions, overlapped
plt.bar(xx, q, color='beige')
plt.bar(xx, p, alpha=.6, color='yellowgreen')
plt.show()

# Fig. 4: logarithms of the two distributions
plt.bar(xx, logq, color='beige')
plt.bar(xx, logp, alpha=.6, color='yellowgreen')
plt.show()

# Fig. 1: two continuous densities (magnified by 10) with their logarithms,
# shading the gap D between the log curves
x = np.arange(-3, 2.5, .001)
plt.plot(x, 10*skewnorm.pdf(x, -1.2), color='black')
plt.plot(x, 10*norm.pdf(x, scale=1.1), color='yellowgreen')
log1 = np.log(skewnorm.pdf(x, -1.2))
log2 = np.log(norm.pdf(x, scale=1.1))
plt.plot(x, log1, color='black')
plt.plot(x, log2, color='yellowgreen')
plt.fill_between(x, log1, log2,
                 where=log1>=log2, facecolor='darkgrey',
                 interpolate=True)
plt.fill_between(x, log1, log2,
                 where=log1<log2, facecolor='lightgreen',
                 interpolate=True)
plt.show()

Feel free to email me for comments, questions, suggestions or if you just want to leave a message.

Multinomial Logistic Regression

In this post we will show how to quickly build a simple model for digit classification using TensorFlow 2 on the MNIST dataset. We just need base Python, NumPy, Matplotlib and a recent version of TensorFlow (2.X).

The model

Our aim is to train a simple classifier with SGD (or other optimizers like Adam) using TensorFlow.

The idea is to apply a very simple transformation to the input, obtaining a vector that, suitably adjusted, expresses information about which class the input belongs to.

Let W be a matrix and b a bias vector (both trainable). Starting from the (flattened) input x, the linear transformation Wx+b produces a logits vector, which is then squashed with the softmax function to obtain a probability distribution. The i-th entry of this distribution vector represents the probability that the input belongs to the i-th class (see figure below).

For each class (each digit in the case of MNIST dataset) we need to calculate a logit (using a linear function)

z_k = w_k \cdot x + b_k \quad (k=0,\dots,9)

and transform logits to valid probabilities p_k with softmax

\displaystyle p_k = \frac{e^{z_k}}{\sum_{i=0}^9 e^{z_i} } \quad (k=0,\dots,9)\,.

For our model, x is a flattened vector coming from a digit image and w_k is a row of the weight matrix. The model just described is known by a variety of names, including Multinomial Logistic Regression and Softmax Regression.
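A tiny NumPy sketch of the two formulas above, with random weights standing in for the trained W and b:

import numpy as np

rng = np.random.default_rng(0)
x = rng.random(784)                 # a flattened 28x28 image
W = rng.normal(size=(10, 784))      # one weight row w_k per class
b = rng.normal(size=10)

z = W @ x + b                       # logits z_k = w_k . x + b_k
p = np.exp(z - z.max())             # softmax (shifted for numerical stability)
p = p / p.sum()
print(p, p.sum())                   # a valid probability distribution summing to 1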

We will use cross-entropy loss to train our multi-class classifier. In particular, since we have labels representing digit classes that are integers (and not one-hot vectors), TensorFlow has a nice loss function that fits this case: SparseCategoricalCrossentropy.

The code

We start by importing NumPy, Matplotlib and TensorFlow.

import numpy as np

from matplotlib import pyplot as plt
%matplotlib inline

import tensorflow as tf
print("We're using TF", tf.__version__)

The MNIST dataset consists of 28×28 images of digits from 0 to 9 (60000 training images and 10000 test images, as loaded below). We will train a classifier on this data.

from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

Time for some dataset visualization.

print("x_train [shape %s] sample patch:\n" % (str(x_train.shape)), 
      x_train[1, 15:20, 5:10])
print("A closeup of a sample patch:")
plt.imshow(x_train[1, 15:20, 5:10], cmap="Greys")
plt.show()

print("And the whole sample:")
plt.imshow(x_train[0], cmap="Greys")
plt.show()

print("y_train [shape %s] 10 samples:\n" % (str(y_train.shape)),
      y_train[:10])

Normalize image values from [0,255] to [0,1].

x_train, x_test = x_train / 255., x_test / 255.

Here’s our (very simple) model, Keras-style built.

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(10, activation='softmax')])

model.summary()

Training over 10 epochs we get an accuracy ~93%.

model.compile(optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])

model.fit(x_train, y_train, epochs=10)

Visualizing results

We compare some predicted digits with the actual digits.

predictions = model.predict(x_test)
predictions = np.argmax(predictions, axis=1)
print(predictions[:10])
print(y_test[:10])
[7 2 1 0 4 1 4 9 6 9]
[7 2 1 0 4 1 4 9 5 9]

Just one mismatch on the first 10 examples.

n_to_show = 8
indices = np.random.choice(range(len(x_test)), n_to_show)

fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for i, idx in enumerate(indices):
    img = x_test[idx]
    ax = fig.add_subplot(1, n_to_show, i+1)
    ax.axis('off')
    ax.text(0.5, -0.4, 
            'predicted = ' + str(predictions[idx]),
            fontsize=10, 
            ha='center',
            transform=ax.transAxes)
    ax.text(0.5, -0.7, 
            'actual = ' + str(y_test[idx]),
            fontsize=10, 
            ha='center', 
            transform=ax.transAxes)
    ax.imshow(img, cmap='binary')

[Jupyter Notebook]

Feel free to email me for comments, questions, suggestions or if you just want to leave a message.