
The Transformer Family Version 2.0

This article is a major update to Lilian Weng's 2020 post on the Transformer family, doubling its length. It systematically reviews numerous recent improvements to the Transformer architecture, covering attention mechanisms, positional encoding, long-context support, adaptive modeling, and efficient attention, including the latest advances such as Transformer-XL, Rotary position embedding, ALiBi, and the Universal Transformer.



Date: January 27, 2023 | Estimated Reading Time: 45 min | Author: Lilian Weng

Table of Contents

Notations

Transformer Basics

Attention and Self-Attention

Multi-Head Self-Attention

Encoder-Decoder Architecture

Positional Encoding

Sinusoidal Positional Encoding

Learned Positional Encoding

Relative Position Encoding

Rotary Position Embedding

Longer Context

Context Memory

Non-Differentiable External Memory

Distance-Enhanced Attention Scores

Make it Recurrent

Adaptive Modeling

Adaptive Attention Span

Depth-Adaptive Transformer

Efficient Attention

Sparse Attention Patterns

Fixed Local Context

Strided Context

Combination of Local and Global Context

Content-based Attention

Low-Rank Attention

Transformers for Reinforcement Learning

Citation

References

Many new Transformer architecture improvements have been proposed since my last post on “The Transformer Family” about three years ago. Here I did a big refactoring and enrichment of that 2020 post: I restructured the hierarchy of sections and improved many sections with more recent papers. Version 2.0 is a superset of the old version, about twice the length.

Notations

Symbol Meaning

$d$ The model size / hidden state dimension / positional encoding size.

$h$ The number of heads in multi-head attention layer.

$L$ The segment length of input sequence.

$N$ The total number of attention layers in the model; not considering MoE.

$\mathbf{X} \in \mathbb{R}^{L \times d}$ The input sequence where each element has been mapped into an embedding vector of shape $d$, same as the model size.

$\mathbf{W}^k \in \mathbb{R}^{d \times d_k}$ The key weight matrix.

$\mathbf{W}^q \in \mathbb{R}^{d \times d_k}$ The query weight matrix.

$\mathbf{W}^v \in \mathbb{R}^{d \times d_v}$ The value weight matrix. Often we have $d_k = d_v = d$.

$\mathbf{W}^k_i, \mathbf{W}^q_i \in \mathbb{R}^{d \times d_k/h}; \mathbf{W}^v_i \in \mathbb{R}^{d \times d_v/h}$ The weight matrices per head.

$\mathbf{W}^o \in \mathbb{R}^{d_v \times d}$ The output weight matrix.

$\mathbf{Q} = \mathbf{X}\mathbf{W}^q \in \mathbb{R}^{L \times d_k}$ The query embedding inputs.

$\mathbf{K} = \mathbf{X}\mathbf{W}^k \in \mathbb{R}^{L \times d_k}$ The key embedding inputs.

$\mathbf{V} = \mathbf{X}\mathbf{W}^v \in \mathbb{R}^{L \times d_v}$ The value embedding inputs.

$\mathbf{q}_i, \mathbf{k}_i \in \mathbb{R}^{d_k}, \mathbf{v}_i \in \mathbb{R}^{d_v}$ Row vectors in query, key, value matrices, $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$.

$S_i$ A collection of key positions for the $i$-th query $\mathbf{q}_i$ to attend to.

$\mathbf{A} \in \mathbb{R}^{L \times L}$ The self-attention matrix between an input sequence of length $L$ and itself. $\mathbf{A} = \text{softmax}(\mathbf{Q}\mathbf{K}^\top / \sqrt{d_k})$.

$a_{ij} \in \mathbf{A}$ The scalar attention score between query $\mathbf{q}_i$ and key $\mathbf{k}_j$.

$\mathbf{P} \in \mathbb{R}^{L \times d}$ position encoding matrix, where the $i$-th row $\mathbf{p}_i$ is the positional encoding for input $\mathbf{x}_i$.

Transformer Basics

The Transformer (which will be referred to as “vanilla Transformer” to distinguish it from other enhanced versions; Vaswani, et al., 2017) model has an encoder-decoder architecture, as commonly used in many NMT models. Later, simplified Transformers were shown to achieve great performance in language modeling tasks, as in encoder-only BERT or decoder-only GPT.

Attention and Self-Attention

Attention is a mechanism in neural networks by which a model learns to make predictions by selectively attending to a given set of data. The amount of attention is quantified by learned weights, and thus the output is usually formed as a weighted average.

Self-attention is a type of attention mechanism where the model makes predictions for one part of a data sample using other parts of the observation about the same sample. Conceptually, it feels quite similar to non-local means. Also note that self-attention is permutation-invariant; in other words, it is an operation on sets.

There are various forms of attention / self-attention; the Transformer (Vaswani et al., 2017) relies on scaled dot-product attention: given a query matrix $\mathbf{Q}$, a key matrix $\mathbf{K}$ and a value matrix $\mathbf{V}$, the output is a weighted sum of the value vectors, where the weight assigned to each value slot is determined by the dot-product of the query with the corresponding key:

$$ \text{attn}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}(\frac{\mathbf{Q} {\mathbf{K}}^\top}{\sqrt{d_k}})\mathbf{V} $$

And for a query and a key vector $\mathbf{q}_i, \mathbf{k}_j \in \mathbb{R}^{d_k}$ (row vectors in query and key matrices), we have a scalar score:

$$ a_{ij} = \text{softmax}(\frac{\mathbf{q}_i {\mathbf{k}_j}^\top}{\sqrt{d_k}}) = \frac{\exp(\frac{\mathbf{q}_i {\mathbf{k}_j}^\top}{\sqrt{d_k}})}{ \sum_{r \in \mathcal{S}_i} \exp(\frac{\mathbf{q}_i {\mathbf{k}_r}^\top}{\sqrt{d_k}}) } $$

where $\mathcal{S}_i$ is a collection of key positions for the $i$-th query to attend to.
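A minimal pure-Python sketch of the scaled dot-product attention above, for illustration only (not a reference implementation). It works on plain lists of row vectors; the optional `allowed` callback is a hypothetical helper that returns the key positions $\mathcal{S}_i$ for query $i$, so passing `range(i + 1)` restricts each query to earlier positions, i.e. causal masking.

```python
import math
from typing import Callable, Iterable, List, Optional

Matrix = List[List[float]]

def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax: subtracting the max does not change the result."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q: Matrix, K: Matrix, V: Matrix,
              allowed: Optional[Callable[[int], Iterable[int]]] = None) -> Matrix:
    """softmax(Q K^T / sqrt(d_k)) V, computed row by row.

    allowed(i) returns S_i, the key positions query i may attend to;
    None means every query attends to every key.
    """
    d_k = len(K[0])
    d_v = len(V[0])
    out = []
    for i, q in enumerate(Q):
        S_i = list(allowed(i)) if allowed is not None else list(range(len(K)))
        # Scaled dot-product score between q_i and each allowed key k_j.
        scores = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d_k) for j in S_i]
        weights = softmax(scores)  # the a_ij for j in S_i
        # Output row i is the attention-weighted average of the value rows.
        out.append([sum(w * V[j][c] for w, j in zip(weights, S_i)) for c in range(d_v)])
    return out

# With all-zero keys every score is equal, so the output row is the mean of V.
print(attention([[1.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]]))  # [[2.0, 3.0]]
# Causal masking: query 0 can only see key 0, so its output is exactly V[0].
causal = attention([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]],
                   [[1.0, 2.0], [3.0, 4.0]], allowed=lambda i: range(i + 1))
print(causal[0])  # [1.0, 2.0]
```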

See my old post for other types of attention if interested.

Multi-Head Self-Attention

The multi-head self-attention module is a key component in Transformer. Rather than only computing the attention once, the multi-head mechanism splits the inputs into smaller chunks and then computes the scaled dot-product attention over each subspace in parallel. The independent attention outputs are simply concatenated and linearly transformed into the expected dimensions.

$$ \begin{aligned} \text{MultiHeadAttn}(\mathbf{X}_q, \mathbf{X}_k, \mathbf{X}_v) &= [\text{head}_1; \dots; \text{head}_h] \mathbf{W}^o \\ \text{where head}_i &= \text{Attention}(\mathbf{X}_q\mathbf{W}^q_i, \mathbf{X}_k\mathbf{W}^k_i, \mathbf{X}_v\mathbf{W}^v_i) \end{aligned} $$

where $[.;.]$ is a concatenation operation. $\mathbf{W}^q_i, \mathbf{W}^k_i \in \mathbb{R}^{d \times d_k/h}, \mathbf{W}^v_i \in \mathbb{R}^{d \times d_v/h}$ are weight matrices to map input embeddings of size $L \times d$ into query, key and value matrices. And $\mathbf{W}^o \in \mathbb{R}^{d_v \times d}$ is the output linear transformation. All the weights should be learned during training.

Illustration of the multi-head scaled dot-product attention mechanism. (Image source: Figure 2 in Vaswani, et al., 2017)
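The multi-head formula can be sketched in pure Python as follows. This is an illustrative sketch that assumes $d_k = d_v = d$, self-attention (the same $\mathbf{X}$ is passed as $\mathbf{X}_q$, $\mathbf{X}_k$, $\mathbf{X}_v$), and small random Gaussian weights standing in for learned ones; `rand_matrix` is a hypothetical helper.

```python
import math
import random
from typing import List

Matrix = List[List[float]]

def matmul(A: Matrix, B: Matrix) -> Matrix:
    """(n x m) @ (m x p) on lists of rows."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def softmax_rows(M: Matrix) -> Matrix:
    out = []
    for row in M:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        s = sum(e)
        out.append([x / s for x in e])
    return out

def attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(K[0])
    K_T = [list(col) for col in zip(*K)]
    scores = [[x / math.sqrt(d_k) for x in row] for row in matmul(Q, K_T)]
    return matmul(softmax_rows(scores), V)

def multi_head_attention(Xq: Matrix, Xk: Matrix, Xv: Matrix,
                         Wq: List[Matrix], Wk: List[Matrix], Wv: List[Matrix],
                         Wo: Matrix) -> Matrix:
    """Wq, Wk, Wv hold h per-head projections (d x d_k/h, d x d_v/h); Wo is d_v x d."""
    heads = [attention(matmul(Xq, wq), matmul(Xk, wk), matmul(Xv, wv))
             for wq, wk, wv in zip(Wq, Wk, Wv)]
    # [head_1; ...; head_h]: concatenate the per-head outputs row by row.
    concat = [[x for head in heads for x in head[r]] for r in range(len(Xq))]
    return matmul(concat, Wo)

rng = random.Random(0)  # seeded so the toy run is reproducible

def rand_matrix(rows: int, cols: int) -> Matrix:
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

L, d, h = 3, 8, 2
X = rand_matrix(L, d)
Wq = [rand_matrix(d, d // h) for _ in range(h)]
Wk = [rand_matrix(d, d // h) for _ in range(h)]
Wv = [rand_matrix(d, d // h) for _ in range(h)]
Wo = rand_matrix(d, d)
Y = multi_head_attention(X, X, X, Wq, Wk, Wv, Wo)
print(len(Y), len(Y[0]))  # 3 8
```

Each head works in a $d/h$-dimensional subspace, so the concatenation restores width $d$ before the output projection.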

Encoder-Decoder Architecture

The encoder generates an attention-based representation with the capability to locate a specific piece of information from a large context. It consists of a stack of 6 identical modules, each containing two submodules: a multi-head self-attention layer and a point-wise fully connected feed-forward network. Point-wise means that the same linear transformation (with the same weights) is applied to each element in the sequence. This can also be viewed as a convolutional layer with filter size 1. Each submodule has a residual connection and layer normalization. All the submodules output data of the same dimension $d$.

The function of the Transformer decoder is to retrieve information from the encoded representation. The architecture is quite similar to the encoder, except that the decoder contains two multi-head attention submodules instead of one in each identical repeating module. The first multi-head attention submodule is masked to prevent positions from attending to the future, while the second attends over the encoder output.

The architecture of the vanilla Transformer model. (Image source: Figure 17)

Positional Encoding

Because the self-attention operation is permutation invariant, it is important to use proper positional encoding to provide order information to the model. The positional encoding $\mathbf{P} \in \mathbb{R}^{L \times d}$ has the same dimension as the input embedding, so it can be added to the input directly. The vanilla Transformer considered two types of encodings:

Sinusoidal Positional Encoding

Sinusoidal positional encoding is defined as follows, given the token position $i=1,\dots,L$ and the dimension $\delta=1,\dots,d$:

$$ \text{PE}(i,\delta) = \begin{cases} \sin(\frac{i}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta'\\ \cos(\frac{i}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta' + 1\\ \end{cases} $$

In this way, each dimension of the positional encoding corresponds to a sinusoid, with wavelengths ranging from $2\pi$ to $10000 \cdot 2\pi$ across dimensions.

Sinusoidal positional encoding with $L=32$ and $d=128$. The value is between -1 (black) and 1 (white) and the value 0 is in gray.
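A small sketch that builds the sinusoidal encoding matrix from the formula above, with the same $L=32$, $d=128$ as the figure. It assumes 0-indexed dimensions (even index uses $\sin$, odd index uses $\cos$) and 1-indexed positions; this indexing is a choice made for the example.

```python
import math
from typing import List

def sinusoidal_pe(L: int, d: int, base: float = 10000.0) -> List[List[float]]:
    """L x d sinusoidal positional encoding; row i - 1 encodes position i = 1..L.

    Dimension delta = 2*delta' uses sin and delta = 2*delta' + 1 uses cos,
    both with angle i / base^(2*delta'/d).
    """
    P = []
    for i in range(1, L + 1):
        row = []
        for delta in range(d):
            delta_prime = delta // 2
            angle = i / base ** (2 * delta_prime / d)
            row.append(math.sin(angle) if delta % 2 == 0 else math.cos(angle))
        P.append(row)
    return P

P = sinusoidal_pe(L=32, d=128)
# The first dimension pair has wavelength 2*pi; wavelengths grow geometrically
# with the pair index toward 10000 * 2*pi.
print(len(P), len(P[0]))  # 32 128
```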

Learned Positional Encoding

Learned positional encoding assigns each element a learned column vector that encodes its absolute position (Gehring, et al. 2017); furthermore, this encoding can be learned differently per layer (Al-Rfou et al. 2018).

Relative Position Encoding

Shaw et al. (2018) incorporated relative positional information into $\mathbf{W}^k$ and $\mathbf{W}^v$. The relative position is clipped to a maximum absolute value of $k$, and this clipping operation enables the model to generalize to unseen sequence lengths. Therefore, $2k + 1$ unique edge labels are considered, and we denote $\mathbf{P}^k, \mathbf{P}^v \in \mathbb{R}^{2k+1}$ as learnable relative position representations.

$$ A_{ij}^k = P^k_{\text{clip}(j - i, k)} \quad A_{ij}^v = P^v_{\text{clip}(j - i, k)} \quad \text{where }\text{clip}(x, k) = \text{clip}(x, -k, k) $$
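To make the clipping concrete, here is a sketch that maps every $(i, j)$ pair to one of the $2k+1$ edge labels. The scalar values in `P_k` are hypothetical stand-ins for the learnable $\mathbf{P}^k$; the shift by $k$ only converts $[-k, k]$ into list indices $0..2k$.

```python
from typing import List

def clip(x: int, k: int) -> int:
    """clip(x, -k, k)."""
    return max(-k, min(k, x))

def relative_index(L: int, k: int) -> List[List[int]]:
    """For each (i, j), the row index into the 2k + 1 learnable representations.

    clip(j - i, k) lies in [-k, k]; adding k shifts it to 0..2k.
    """
    return [[clip(j - i, k) + k for j in range(L)] for i in range(L)]

k = 2
P_k = [0.1 * n for n in range(2 * k + 1)]     # hypothetical values, one per edge label
idx = relative_index(L=5, k=k)
A_k = [[P_k[n] for n in row] for row in idx]   # A^k_ij = P^k_{clip(j - i, k)}
print(idx[0])  # [2, 3, 4, 4, 4]: every offset beyond +k shares one label
```

However long the sequence, only $2k+1$ representations are ever looked up, which is why the scheme transfers to unseen lengths.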

Transformer-XL (Dai et al., 2019) proposed a type of relative positional encoding based on reparametrization of the dot-product of keys and queries. To keep positional information flowing coherently across segments, Transformer-XL encodes the relative position instead, since knowing the position offset $i-j$ between a key vector $\mathbf{k}_{\tau, j}$ and its query $\mathbf{q}_{\tau, i}$ could be sufficient for making good predictions.

If omitting the scalar $1/\sqrt{d_k}$ and the normalizing term in softmax but including positional encodings, we can write the attention score between query at position $i$ and key at position $j$ as:

$$ \begin{aligned} a_{ij} &= \mathbf{q}_i {\mathbf{k}_j}^\top = (\mathbf{x}_i + \mathbf{p}_i)\mathbf{W}^q ((\mathbf{x}_j + \mathbf{p}_j)\mathbf{W}^k)^\top \\ &= \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top \end{aligned} $$

Transformer-XL reparameterizes the above four terms as follows:

$$ a_{ij}^\text{rel} = \underbrace{ \mathbf{x}_i\mathbf{W}^q \color{blue}{ {\mathbf{W}_E^k}^\top } \mathbf{x}_j^\top }_\text{content-based addressing} + \underbrace{ \mathbf{x}_i\mathbf{W}^q \color{blue}{ {\mathbf{W}_R^k}^\top } \color{green}{\mathbf{r}_{i-j}^\top} }_\text{content-dependent positional bias} + \underbrace{ \color{red}{\mathbf{u}} \color{blue}{ {\mathbf{W}_E^k}^\top } \mathbf{x}_j^\top }_\text{global content bias} + \underbrace{ \color{red}{\mathbf{v}} \color{blue}{ {\mathbf{W}_R^k}^\top } \color{green}{\mathbf{r}_{i-j}^\top} }_\text{global positional bias} $$

Replace $\mathbf{p}_j$ with relative positional encoding $\mathbf{r}_{i-j} \in \mathbb{R}^{d}$;

Replace $\mathbf{p}_i\mathbf{W}^q$ with two trainable parameters $\mathbf{u}$ (for content) and $\mathbf{v}$ (for location) in two different terms;

Split $\mathbf{W}^k$ into two matrices, $\mathbf{W}^k_E$ for content information and $\mathbf{W}^k_R$ for location information.
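The four reparameterized terms can be computed directly, as in this sketch on toy vectors. Here `r_ij` stands for $\mathbf{r}_{i-j}$ and is simply supplied by the caller; $\mathbf{u}$ and $\mathbf{v}$ live in the projected key space, so they are dotted with the projected keys.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]

def vecmat(x: Vector, W: Matrix) -> Vector:
    """Row vector times matrix: (1 x n) @ (n x m)."""
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]

def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))

def rel_score(x_i: Vector, x_j: Vector, r_ij: Vector, Wq: Matrix, WkE: Matrix,
              WkR: Matrix, u: Vector, v: Vector) -> float:
    """a_ij^rel as the sum of the four terms; r_ij stands for r_{i-j}."""
    q = vecmat(x_i, Wq)           # x_i W^q
    k_content = vecmat(x_j, WkE)  # x_j W^k_E
    k_pos = vecmat(r_ij, WkR)     # r_{i-j} W^k_R
    return (dot(q, k_content)     # content-based addressing
            + dot(q, k_pos)       # content-dependent positional bias
            + dot(u, k_content)   # global content bias
            + dot(v, k_pos))      # global positional bias

I2 = [[1, 0], [0, 1]]
score = rel_score([1, 0], [0, 1], [1, 1], I2, I2, I2, u=[0.5, 0.0], v=[0.0, 2.0])
print(score)  # 3.0 (= 0 + 1 + 0 + 2)
```

The sum regroups as $(\mathbf{x}_i\mathbf{W}^q + \mathbf{u})\cdot(\mathbf{x}_j\mathbf{W}^k_E) + (\mathbf{x}_i\mathbf{W}^q + \mathbf{v})\cdot(\mathbf{r}_{i-j}\mathbf{W}^k_R)$, so an implementation can add $\mathbf{u}$ and $\mathbf{v}$ to the query once instead of computing four separate products.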

Rotary Position Embedding

Rotary position embedding (RoPE; Su et al. 2021) encodes the absolute position with a rotation matrix and multiplies the key and query matrices of every attention layer with it to inject relative positional information at every layer.

When encoding relative positional information into the inner product of the $i$-th query and the $j$-th key, we would like to formulate the function in a way that the inner product depends only on the relative position $i-j$. Rotary Position Embedding (RoPE) makes use of the rotation operation in Euclidean space and frames the relative position embedding as simply rotating the feature matrix by an angle proportional to its position index.

Given a vector $\mathbf{z}$, if we want to rotate it counterclockwise by $\theta$, we can multiply it by a rotation matrix to get $R\mathbf{z}$ where the rotation matrix $R$ is defined as:

$$ R = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix} $$

When generalizing to higher dimensional space, RoPE divides the $d$-dimensional space into $d/2$ subspaces and constructs a rotation matrix $R$ of size $d \times d$ for the token at position $i$:

$$ R^d_{\Theta, i} = \begin{bmatrix} \cos i\theta_1 & -\sin i\theta_1 & 0 & 0 & \dots & 0 & 0 \\ \sin i\theta_1 & \cos i\theta_1 & 0 & 0 & \dots & 0 & 0 \\ 0 & 0 & \cos i\theta_2 & -\sin i\theta_2 & \dots & 0 & 0 \\ 0 & 0 & \sin i\theta_2 & \cos i\theta_2 & \dots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \dots & \cos i\theta_{d/2} & -\sin i\theta_{d/2} \\ 0 & 0 & 0 & 0 & \dots & \sin i\theta_{d/2} & \cos i\theta_{d/2} \\ \end{bmatrix} $$

where in the paper we have $\Theta = \{\theta_i = 10000^{-2(i-1)/d}, i \in [1, 2, \dots, d/2]\}$. Note that this is essentially equivalent to sinusoidal positional encoding but formulated as a rotation matrix.

Then both key and query matrices incorporate the positional information by multiplying with this rotation matrix:

$$ \begin{aligned} & \mathbf{q}_i^\top \mathbf{k}_j = (R^d_{\Theta, i} \mathbf{W}^q\mathbf{x}_i)^\top (R^d_{\Theta, j} \mathbf{W}^k\mathbf{x}_j) = \mathbf{x}_i^\top {\mathbf{W}^q}^\top R^d_{\Theta, j-i}\mathbf{W}^k\mathbf{x}_j \\ & \text{ where } R^d_{\Theta, j-i} = (R^d_{\Theta, i})^\top R^d_{\Theta, j} \end{aligned} $$

Visual illustration of how rotary position embedding is implemented. (Image source: Su et al., 2021) Note: I used $i$ instead of $m$ to represent the position index compared to the original figure in the paper.
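Because $R^d_{\Theta, i}$ is block-diagonal, it never needs to be materialized: each feature pair is rotated independently. The sketch below assumes 0-indexed pairs $p = 0..d/2-1$ (so $p = i - 1$ in the notation above), and `q`, `k` are arbitrary toy vectors standing in for the projected query and key. It checks numerically that the score depends only on the position offset.

```python
import math
from typing import List

def rope(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Apply R^d_{Theta,pos} to x (even length d) without building the d x d matrix.

    Each pair (x[2p], x[2p+1]) is rotated counterclockwise by pos * theta_p,
    with theta_p = base^(-2p/d).
    """
    d = len(x)
    out = list(x)
    for p in range(d // 2):
        theta = base ** (-2 * p / d)
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        x0, x1 = x[2 * p], x[2 * p + 1]
        out[2 * p] = c * x0 - s * x1
        out[2 * p + 1] = s * x0 + c * x1
    return out

def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.5, 2.0]   # stands for W^q x_i
k = [1.1, 0.4, -0.7, 0.9]   # stands for W^k x_j
# Positions (2, 5) and (10, 13) share the offset j - i = 3, so the scores match.
s1 = dot(rope(q, 2), rope(k, 5))
s2 = dot(rope(q, 10), rope(k, 13))
print(abs(s1 - s2) < 1e-9)  # True
```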

Longer Context

The length of an input sequence for transformer models at inference time is upper-bounded by the context length used for training. Naively increasing context length leads to high consumption in both time ($\mathcal{O}(L^2d)$) and memory ($\mathcal{O}(L^2)$) and may not be supported due to hardware constraints.

This section introduces several improvements in transformer architecture to better support long context at inference, e.g. using additional memory, designing for better context extrapolation, or adding a recurrence mechanism.

Context Memory

The vanilla Transformer has a fixed and limited attention span. The model can only attend to other elements in the same segment during each update step, and no information can flow across separated fixed-length segments. This context segmentation causes several issues:

The model cannot capture very long term dependencies.

It is hard to predict the first few tokens in each segment given no or thin context.

The evaluation is expensive. Whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, even though many tokens overlap.

Transformer-XL (Dai et al., 2019; “XL” means “extra long”) modifies the architecture to reuse hidden states between segments with an additional memory. The recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments.

A comparison between the training phase of vanilla Transformer & Transformer-XL with a segment length 4. (Image source: left part of Figure 2 in Dai et al., 2019).

Let’s label the hidden state of the $n$-th layer for the $(\tau + 1)$-th segment in the model as $\mathbf{h}_{\tau+1}^{(n)} \in \mathbb{R}^{L \times d}$. In addition to the hidden state of the previous layer for the same segment, $\mathbf{h}_{\tau+1}^{(n-1)}$, it also depends on the hidden state of the previous layer for the previous segment, $\mathbf{h}_{\tau}^{(n-1)}$. By incorporating information from the previous hidden states, the model extends the attention span much further into the past, over multiple segments.

$$ \begin{aligned} \color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} &= [\text{stop-gradient}(\mathbf{h}_{\tau}^{(n-1)}) \circ \mathbf{h}_{\tau+1}^{(n-1)}] \\ \mathbf{Q}_{\tau+1}^{(n)} &= \mathbf{h}_{\tau+1}^{(n-1)}\mathbf{W}^q \\ \mathbf{K}_{\tau+1}^{(n)} &= \color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} \mathbf{W}^k \\ \mathbf{V}_{\tau+1}^{(n)} &= \color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} \mathbf{W}^v \\ \mathbf{h}_{\tau+1}^{(n)} &= \text{transformer-layer}(\mathbf{Q}_{\tau+1}^{(n)}, \mathbf{K}_{\tau+1}^{(n)}, \mathbf{V}_{\tau+1}^{(n)}) \end{aligned} $$

Note that both keys and values rely on extended hidden states, while queries only consume hidden states at the current step. The concatenation operation $[. \circ .]$ is along the sequence length dimension. And Transformer-XL needs to use relative positional encoding because previous and current segments would be assigned the same encoding if we encoded absolute positions, which is undesired.
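The shape bookkeeping of the equations above can be sketched as follows: queries come only from the current segment, while keys and values come from the memory concatenated with it. Plain Python lists carry no gradients, so stop-gradient is implicit here; the memory-update policy in `update_memory` (keep the most recent `m` rows) is an assumption for the example.

```python
from typing import List, Tuple

Matrix = List[List[float]]

def matmul(A: Matrix, B: Matrix) -> Matrix:
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def xl_qkv(mem: Matrix, h: Matrix, Wq: Matrix, Wk: Matrix,
           Wv: Matrix) -> Tuple[Matrix, Matrix, Matrix]:
    """mem: cached h_tau^(n-1) from the previous segment; h: h_{tau+1}^(n-1).

    In an autodiff framework the cached rows would be detached before use.
    """
    h_ext = mem + h        # [stop-gradient(mem) o h] along the sequence axis
    Q = matmul(h, Wq)      # queries: current segment only
    K = matmul(h_ext, Wk)  # keys: memory + current segment
    V = matmul(h_ext, Wv)  # values: memory + current segment
    return Q, K, V

def update_memory(mem: Matrix, h: Matrix, m: int) -> Matrix:
    """Keep the most recent m hidden states as the cache for the next segment."""
    return (mem + h)[-m:]

L, d = 4, 3
I = [[1.0 if r == c else 0.0 for c in range(d)] for r in range(d)]
seg1 = [[float(t)] * d for t in range(L)]
seg2 = [[float(t)] * d for t in range(L, 2 * L)]
Q, K, V = xl_qkv(seg1, seg2, I, I, I)
print(len(Q), len(K), len(V))  # 4 8 8
```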

Compressive Transformer (Rae et al. 2019) extends Transformer-XL by compressing past memories to support longer sequences. It explicitly adds memory slots of size $m_m$ per layer for storing past activations of this layer to preserve long context. When some past activations become old enough, they are compressed and saved in an additional compressed memory of size $m_{cm}$ per layer.

Compressive transformer maintains two types of memory slots, memory and compressed memory, to support long context. (Image source: Rae et al. 2019).

Both memory and compressed memory are FIFO queues. Given the model context length $L$, the compression function of compression rate $c$ is defined as $f_c: \mathbb{R}^{L \times d} \to \mathbb{R}^{[\frac{L}{c}] \times d}$, mapping the $L$ oldest activations to $[\frac{L}{c}]$ compressed memory elements. There are several choices of compression functions:

Max/mean pooling of kernel and stride size $c$;

1D convolution with kernel and stride size $c$ (need to learn additional parameters);

Dilated convolution (need to learn additional parameters). In their experiments, convolution compression works best on the EnWik8 dataset;

Most used memories.
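Below is a toy sketch of the two FIFO queues using the mean-pooling option from the list above. The sizes are hypothetical, and the sketch assumes that whatever overflows the regular memory is compressed immediately and that $c$ divides the number of evicted rows.

```python
from collections import deque
from typing import Deque, List

Row = List[float]

def compress_mean(old: List[Row], c: int) -> List[Row]:
    """f_c as mean pooling with kernel and stride c: (n x d) -> (n/c x d)."""
    if len(old) % c:
        raise ValueError("number of evicted activations must be divisible by c")
    return [[sum(col) / c for col in zip(*old[s:s + c])] for s in range(0, len(old), c)]

def push(memory: Deque[Row], compressed: Deque[Row], new_acts: List[Row],
         m_m: int, m_cm: int, c: int) -> None:
    """Append new activations; evict the oldest into compressed memory (both FIFO)."""
    memory.extend(new_acts)
    evicted = [memory.popleft() for _ in range(max(0, len(memory) - m_m))]
    if evicted:
        compressed.extend(compress_mean(evicted, c))
        while len(compressed) > m_cm:
            compressed.popleft()  # the oldest compressed slots are discarded

memory: Deque[Row] = deque()
compressed: Deque[Row] = deque()
for start in (1, 5, 9):  # three segments of four 1-d activations
    push(memory, compressed, [[float(t)] for t in range(start, start + 4)], m_m=4, m_cm=2, c=2)
print(list(memory), list(compressed))  # [[9.0], [10.0], [11.0], [12.0]] [[5.5], [7.5]]
```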

Compressive transformer has two additional training losses:

Auto-encoding loss (lossless compression objective) measures how well we can reconstruct the original memories from compressed memories:

$$ \mathcal{L}_{ac} = \| \textbf{old\_mem}^{(i)} - g(\textbf{new\_cm}^{(i)}) \|_2 $$

where $g: \mathbb{R}^{[\frac{L}{c}] \times d} \to \mathbb{R}^{L \times d}$ reverses the compression function $f_c$.

Attention-reconstruction loss (lossy objective) reconstructs content-based attention over memory vs compressed memory and minimizes the difference:

$$ \mathcal{L}_{ar} = \|\text{attn}(\mathbf{h}^{(i)}, \textbf{old\_mem}^{(i)}) - \text{attn}(\mathbf{h}^{(i)}, \textbf{new\_cm}^{(i)})\|_2 $$

Transformer-XL with a memory of size $m$ has a maximum temporal range of $m \times N$, where $N$ is the number of layers in the model, and attention cost $\mathcal{O}(L^2 + Lm)$. In comparison, the Compressive Transformer has a temporal range of $(m_m + c \cdot m_{cm}) \times N$ and attention cost $\mathcal{O}(L^2 + L(m_m + m_{cm}))$. A larger compression rate $c$ gives a better tradeoff between temporal range length and attention cost.
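A worked comparison of the two formulas, using hypothetical sizes chosen only for illustration ($L = 512$, $N = 12$, a Transformer-XL memory of 512, and a Compressive Transformer split into 256 regular plus 256 compressed slots with $c = 3$):

```python
def xl_temporal_range(m: int, N: int) -> int:
    return m * N

def compressive_temporal_range(m_m: int, m_cm: int, c: int, N: int) -> int:
    return (m_m + c * m_cm) * N

def attention_cost(L: int, memory_slots: int) -> int:
    """Query-key pairs per layer: L^2 within the segment plus L per memory slot."""
    return L * L + L * memory_slots

L, N = 512, 12
print(xl_temporal_range(m=512, N=N))                   # 6144
print(compressive_temporal_range(256, 256, c=3, N=N))  # 12288
print(attention_cost(L, 512), attention_cost(L, 256 + 256))  # 524288 524288
```

With the same number of attended slots per layer, and therefore the same attention cost, compression doubles the temporal range in this configuration.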

Past activations, from oldest to newest, are stored in three locations: compressed memory → memory → causally masked sequence. In the experiments, they observed an increase in attention weights from the oldest activations stored in the regular memory to the activations stored in the compressed memory, even though the compressed memory holds even older activations, implying that the network is learning to preserve salient information.

Attention weights with one standard deviation as error bars versus memory positions, from oldest (left) to newest (right). (Image source: Rae et al. 2019).

Non-Differentiable External Memory

$k$NN-LM (Khandelwal et al. 2020) enhances a pretrained LM with a separate $k$NN model by linearly interpolating the next-token probabilities predicted by both models. The $k$NN model is built upon an external key-value store, which can hold any large pre-training dataset or a new out-of-distribution (OOD) dataset. This datastore is preprocessed to save a large number of (LM embedding representation of context, next token) pairs, and nearest neighbor retrieval happens in the LM embedding space. Because the datastore can be gigantic, we need to rely on libraries for fast dense vector search such as FAISS or ScaNN. The indexing process only happens once, and parallelism is easy to implement at inference time.

At inference time, the next token probability is a weighted sum of two predictions:

$$ \begin{aligned} p(y \vert \mathbf{x}) &= \lambda \; p_\text{kNN}(y \vert \mathbf{x}) + (1- \lambda) \; p_\text{LM}(y \vert \mathbf{x}) \\ p_\text{kNN}(y \vert \mathbf{x}) &\propto \sum_{(k_i, w_i) \in \mathcal{N}} \mathbb{1}[y = w_i] \exp(-d(k_i, f(\mathbf{x}))) \end{aligned} $$

where $\mathcal{N}$ contains a set of nearest neighbor data points retrieved by $k$NN; $d(., .)$ is a distance function such as L2 distance.

According to the experiments, a larger datastore or a larger $k$ is correlated with better perplexity. The weighting scalar $\lambda$ should be tuned, but in general it is expected to be larger for out-of-domain data than for in-domain data, and a larger datastore can afford a larger $\lambda$.
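A sketch of the interpolation at inference time, with the nearest-neighbor search itself left out: `neighbors` plays the role of the retrieved set $\mathcal{N}$ and is hypothetical toy data. Squared L2 is used as $d(., .)$ here, which is one choice consistent with "such as L2 distance", and the minimum distance is subtracted before exponentiating for numerical stability (it cancels after normalization).

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

def sq_l2(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))

def knn_probs(neighbors: List[Tuple[Sequence[float], str]],
              context_emb: Sequence[float]) -> Dict[str, float]:
    """p_kNN(y|x) proportional to the sum of exp(-d(k_i, f(x))) over neighbors with w_i == y."""
    dists = [sq_l2(key, context_emb) for key, _ in neighbors]
    d_min = min(dists)
    scores: Dict[str, float] = defaultdict(float)
    for dist, (_, token) in zip(dists, neighbors):
        scores[token] += math.exp(-(dist - d_min))
    total = sum(scores.values())
    return {y: s / total for y, s in scores.items()}

def interpolate(p_knn: Dict[str, float], p_lm: Dict[str, float], lam: float) -> Dict[str, float]:
    """p(y|x) = lam * p_kNN(y|x) + (1 - lam) * p_LM(y|x)."""
    vocab = set(p_knn) | set(p_lm)
    return {y: lam * p_knn.get(y, 0.0) + (1 - lam) * p_lm.get(y, 0.0) for y in vocab}

neighbors = [([0.0, 0.0], "cat"), ([1.0, 0.0], "cat"), ([3.0, 0.0], "dog")]
p_knn = knn_probs(neighbors, context_emb=[0.0, 0.0])
p_lm = {"cat": 0.2, "dog": 0.3, "bird": 0.5}  # toy LM distribution
p = interpolate(p_knn, p_lm, lam=0.25)
print(round(sum(p.values()), 6))  # 1.0
```

Tokens never retrieved (here `"bird"`) still receive probability through the $(1-\lambda)$ LM term.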

SPALM (Adaptive semiparametric language models; Yogatama et al. 2021) incorporates both (1) Transformer-XL style memory for hidden states from external context as short-term memory and (2) $k$NN-LM style key-value store as long-term memory.

Illustration of how SPALM combines context memory of past hidden states (short term memory) with an external key-value datastore (long term memory) to support longer context. (Image source: Yogatama et al. 2021).

SPALM runs $k$NN search to fetch $k$ tokens with most relevant context. For each token we can get the same embedding representation provided by a pretrained LM, denoted as $\{\mathbf{y}_i\}_{i=1}^k$. The gating mechanism first aggregates the retrieved token embeddings with a simple attention layer using $\mathbf{h}^R_t$ (the hidden state for token $x_t$ at layer $R$) as a query and then learns a gating parameter $\mathbf{g}_t$ to balance between local information $\mathbf{h}^R_t$ and long-term information $\mathbf{m}_t$.

$$ \begin{aligned} \mathbf{m}_t &= \sum_{i=1}^k \frac{\exp(\mathbf{y}_i^\top \mathbf{h}^R_t)}{\sum_{j=1}^k \exp(\mathbf{y}_j^\top \mathbf{h}^R_t)} \cdot \mathbf{y}_i \\ \mathbf{g}_t &= \sigma(\mathbf{w}_g^\top \mathbf{h}_t^R) \\ \mathbf{z}_t &= (1 - \mathbf{g}_t) \odot \mathbf{m}_t + \mathbf{g}_t \odot \mathbf{h}^R_t \\ p(x_{t+1}\mid \mathbf{x}_{\leq t}) &= \text{softmax}(\mathbf{z}_t; \mathbf{W}) \end{aligned} $$

where $\mathbf{w}_g$ is a parameter vector to learn; $\sigma(.)$ is sigmoid; $\mathbf{W}$ is the word embedding matrix shared between both input and output tokens. Different from $k$NN-LM, they didn’t find the nearest neighbor distance to be helpful in the aggregation of retrieved tokens.

During training, the key representations in the long-term memory stay constant, produced by a pretrained LM, but the value encoder, aka the word embedding matrix, gets updated.
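A single SPALM output step following the equations above, sketched on toy vectors. Since $\mathbf{w}_g^\top \mathbf{h}^R_t$ is a scalar, the gate is computed as one scalar and broadcast over $\mathbf{m}_t$ and $\mathbf{h}^R_t$. `W` is a toy word-embedding matrix with one row per vocabulary item; all inputs are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]

def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))

def softmax(xs: Vector) -> Vector:
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]

def spalm_step(h: Vector, ys: List[Vector], w_g: Vector,
               W: List[Vector]) -> Tuple[Vector, Vector]:
    """h: h_t^R; ys: retrieved token embeddings y_1..y_k; returns (z_t, next-token probs)."""
    attn = softmax([dot(y, h) for y in ys])                               # weights over y_i
    m = [sum(a * y[c] for a, y in zip(attn, ys)) for c in range(len(h))]  # m_t
    g = 1.0 / (1.0 + math.exp(-dot(w_g, h)))                              # g_t = sigmoid(w_g^T h_t^R)
    z = [(1 - g) * mc + g * hc for mc, hc in zip(m, h)]                   # z_t
    return z, softmax([dot(row, z) for row in W])                        # softmax(z_t; W)

z, probs = spalm_step(h=[1.0, 0.0], ys=[[0.0, 2.0], [0.0, 2.0]], w_g=[0.0, 0.0],
                      W=[[1.0, 0.0], [0.0, 1.0]])
print(z)  # [0.5, 1.0]: with w_g = 0 the gate is 0.5, an even mix of m_t and h_t^R
```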

Memorizing Transformer (Wu et al. 2022) adds a $k$NN-augmented attention layer near the top of the stack of a decoder-only Transformer. This special layer maintains a Transformer-XL style FIFO cache of past key-value pairs.

The same QKV values are used for both local attention and $k$NN mechanisms. The $k$NN lookup returns top-$k$ (key, value) pairs for each query in the input sequence and then they are processed through the self-attention stack to compute a weighted average of retrieved values. Two types of attention are combined with a learnable per-head gating parameter. To prevent large distributional shifts in value magnitude, both keys and values in the cache are normalized.
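A simplified single-query, single-head sketch of the $k$NN lookup plus gating described above. Entries are L2-normalized on insertion, following the description; the gate is written as $g = \sigma(b_g)$ weighting the memory result and $1 - g$ the local result, which is an assumption about the exact form. `raw` is hypothetical cache content.

```python
import heapq
import math
from typing import List, Tuple

Vector = List[float]

def l2_normalize(v: Vector, eps: float = 1e-6) -> Vector:
    n = math.sqrt(sum(x * x for x in v))
    return [x / max(n, eps) for x in v]

def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))

def softmax(xs: Vector) -> Vector:
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]

def knn_attention(q: Vector, cache: List[Tuple[Vector, Vector]], k: int) -> Vector:
    """Retrieve the top-k cached (key, value) pairs by inner product, then attend over them."""
    top = heapq.nlargest(k, cache, key=lambda kv: dot(q, kv[0]))
    weights = softmax([dot(q, key) for key, _ in top])
    return [sum(w * val[c] for w, (_, val) in zip(weights, top)) for c in range(len(top[0][1]))]

def gated_output(local_out: Vector, mem_out: Vector, b_g: float) -> Vector:
    """Mix the two attention results with a per-head learnable scalar gate sigmoid(b_g)."""
    g = 1.0 / (1.0 + math.exp(-b_g))
    return [g * m + (1 - g) * l for m, l in zip(mem_out, local_out)]

raw = [([2.0, 0.0], [3.0, 4.0]), ([0.0, 5.0], [1.0, 0.0]), ([-1.0, 0.0], [0.0, 1.0])]
cache = [(l2_normalize(key), l2_normalize(val)) for key, val in raw]
mem_out = knn_attention([10.0, 0.0], cache, k=1)  # best match is key [1, 0]
out = gated_output(local_out=[0.0, 0.0], mem_out=mem_out, b_g=0.0)
print(mem_out, out)
```

In the actual model this runs for every query and head in parallel, with local attention over the current context supplying `local_out`.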

What they found during experiments with Memorizing Transformer:

It is observed in some experiments that training models with a small memory and then finetuning with a larger memory works better than training with a large memory from scratch.

The smaller Memorizing Transformer with just 8k tokens in memory can match the perplexity of a larger vanilla Transformer with 5X more trainable parameters.

Increasing the size of external memory provided consistent gains up to a size of 262K.

A non-memory transformer can be finetuned to use memory.

Fine-tuning a vanilla Transformer with a key-value memory can achieve similar performance as training a memorizing transformer from scratch. (Image source: Wu et al. 2022).

Distance-Enhanced Attention Scores

Distance Aware Transformer (DA-Transformer; Wu, et al. 2021) and Attention with Linear Biases (ALiBi; Press et al. 2022) are motivated by similar ideas: to encourage the model to extrapolate over longer context than what it was trained on, we can explicitly attach positional information to every pairwise attention score based on the distance between key and query tokens.

Note that the default positional encoding in vanilla Transformer only adds positional information to the input sequence, while later improved encoding mechanisms, such as rotary position embedding, alter attention scores of every layer, and they take a form very similar to distance-enhanced attention scores.

DA-Transformer (Wu, et al. 2021) multiplies attention scores at each layer by a learnable bias that is formulated as a function of the distance between key and query. Different attention heads use different parameters to distinguish diverse preferences to short-term vs long-term context. Given two positions, $i, j$, DA-Transformer uses the following weighting function to alter the self-attention score:

$$ \begin{aligned} \mathbf{R}^{(i)} &= \alpha_i \mathbf{R} \quad \text{where }R_{ij} = \vert i-j \vert\\ f(\mathbf{R}^{(i)}; \beta_i) &= \frac{1 + \exp(\beta_i)}{1 + \exp(\beta_i - \mathbf{R}^{(i)})} \\ \text{attn}(\mathbf{Q}^{(i)}, \mathbf{K}^{(i)}, \mathbf{V}^{(i)}) &= \text{row-softmax}\Big(\frac{\text{ReLU}(\mathbf{Q}^{(i)}\mathbf{K}^{(i)\top})f(\mathbf{R}^{(i)})}{\sqrt{d}}\Big) \mathbf{V}^{(i)} \end{aligned} $$

where $\alpha_i$ is a learnable parameter to weight relative distance differently per head, where the head is indexed by superscript $^{(i)}$; $\beta_i$ is a learnable parameter to control the upper bound and ascending slope with respect to the distance for the $i$-th attention head. The weighting function $f(.)$ is designed in a way that: (1) $f(0)=1$; (2) $f(\mathbf{R}^{(i)}) = 0$ when $\mathbf{R}^{(i)} \to -\infty$; (3) $f(\mathbf{R}^{(i)})$ is bounded when $\mathbf{R}^{(i)} \to +\infty$; (4) the scale is tunable; and (5) the function is monotonic. The extra time complexity brought by $f(\mathbf{R}^{(i)})$ is $\mathcal{O}(L^2)$, which is small relative to the self-attention time complexity $\mathcal{O}(L^2 d)$. The extra memory consumption is minimal, ~$\mathcal{O}(2h)$.
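A one-head sketch of the DA-Transformer formula, with the weighting function written out so its listed properties can be checked. The $\sqrt{d}$ denominator uses the per-head query/key size (an assumption), and the inputs are toy values. Because $f$ increases with its argument, a negative $\alpha_i$ makes the weight shrink with distance, i.e. favors short-term context.

```python
import math
from typing import List

Matrix = List[List[float]]

def da_weight(r: float, beta: float) -> float:
    """f(r; beta) = (1 + exp(beta)) / (1 + exp(beta - r))."""
    return (1 + math.exp(beta)) / (1 + math.exp(beta - r))

def da_attention(Q: Matrix, K: Matrix, V: Matrix, alpha: float, beta: float) -> Matrix:
    """One head: row-softmax(ReLU(Q K^T) * f(alpha * |i - j|) / sqrt(d)) V."""
    d = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        scores = []
        for j, k in enumerate(K):
            relu = max(0.0, sum(a * b for a, b in zip(q, k)))
            scores.append(relu * da_weight(alpha * abs(i - j), beta) / math.sqrt(d))
        m = max(scores)
        e = [math.exp(s - m) for s in scores]  # row-softmax, shifted for stability
        total = sum(e)
        out.append([sum(w / total * V[j][c] for j, w in enumerate(e)) for c in range(len(V[0]))])
    return out

print(da_weight(0.0, 1.0))     # 1.0, property (1)
print(da_weight(1000.0, 1.0))  # approaches the upper bound 1 + e^beta, property (3)
out = da_attention([[1.0, 0.0]] * 3, [[1.0, 0.0]] * 3, [[1.0], [2.0], [3.0]], alpha=1.0, beta=0.0)
```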

Instead of multipliers, ALiBi (Press et al. 2022) adds a constant bias term to query-key attention scores, proportional to pairwise distances. The bias introduces a strong recency preference and penalizes keys that are too far away. The penalties increase at different rates in different heads. $$ \text{softmax}(\mathbf{q}_i \mathbf{K}^\top + \alpha_i \cdot [-(i-1), \dots, -2, -1, 0]) $$ where $\alpha_i$ is a head-specific weighting scalar. Different from DA-Transformer, $\alpha_i$ is not learned but fixed as a geometric sequence; for example, for 8 heads, $\{\alpha_i\} = \{\frac{1}{2}, \frac{1}{2^2}, \dots, \frac{1}{2^8}\}$. The overall idea is very similar to what relative positional encoding aims to solve.

Illustration of how ALiBi enhances attention scores with a positional bias term. (Image source: Press et al. 2021).
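A sketch of the ALiBi slopes and the causal bias row for one query. For 8 heads `alibi_slopes` reproduces the sequence above; extending it to other head counts by starting the geometric sequence at $2^{-8/n}$ with the same ratio is an assumption here, meant for powers of two. Positions are 0-indexed, so query $i$ sees keys $0..i$.

```python
import math
from typing import List

def alibi_slopes(n_heads: int) -> List[float]:
    """Geometric head slopes starting at 2^(-8/n) with the same ratio."""
    start = 2 ** (-8 / n_heads)
    return [start ** (h + 1) for h in range(n_heads)]

def alibi_bias_row(i: int, slope: float) -> List[float]:
    """Bias added to q_i K^T for keys j = 0..i: slope * (j - i), i.e. 0 for the current token."""
    return [slope * (j - i) for j in range(i + 1)]

def alibi_attention_weights(q: List[float], K: List[List[float]], i: int,
                            slope: float) -> List[float]:
    """Softmax over q_i . k_j + bias_j for the causal keys 0..i."""
    bias = alibi_bias_row(i, slope)
    logits = [sum(a * b for a, b in zip(q, K[j])) + bias[j] for j in range(i + 1)]
    m = max(logits)
    e = [math.exp(x - m) for x in logits]
    s = sum(e)
    return [x / s for x in e]

print(alibi_slopes(8))         # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
print(alibi_bias_row(2, 0.5))  # [-1.0, -0.5, 0.0]
```

With identical query-key scores, the weights decay geometrically with distance at a head-specific rate, which is the recency preference described above.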

With ALiBi, Press et al. (2022) trained a 1.3B model with a context length of 1024 and extrapolated to 2048 at inference time.

Extrapolation experiments for running inference with Transformers of different configs, including sinusoidal positional encoding, rotary positional encoding, simplified relative positional encoding in T5 and ALiBi. All models were trained with small context length but inference ran for much longer context. (Image source: Press et al. 2021).

Make it Recurrent

Universal Transformer (Dehghani, et al. 2019) combines self-attention in Transformer with the recurrent mechanism in RNN, aiming to benefit from both the long-term global receptive field of Transformer and the learned inductive biases of RNN. Rather than going through a fixed number of layers, Universal Transformer dynamically adjusts the number of steps using adaptive computation time. If we fix the number of steps, a Universal Transformer is equivalent to a multi-layer Transformer with shared parameters across layers.

On a high level, the universal transformer can be viewed as a recurrent function for learning the hidden state representation per token. The recurrent function evolves in parallel across token positions and the information between positions is shared through self-attention.

How the Universal Transformer refines a set of hidden state representations repeatedly for every position in parallel. (Image source: Figure 1 in Dehghani, et al. 2019).

Given an input sequence of length $L$, Universal Transformer iteratively updates the representation $\mathbf{h}^t \in \mathbb{R}^{L \times d}$ at step $t$ for an adjustable number of steps. At step 0, $\mathbf{h}^0$ is initialized to be the same as the input embedding matrix. All the positions are processed in parallel in the multi-head self-attention mechanism and then go through a recurrent transition function.

$$ \begin{aligned} \mathbf{A}^t &= \text{LayerNorm}(\mathbf{h}^{t-1} + \text{MultiHeadAttention}(\mathbf{h}^{t-1} + \mathbf{P}^t)) \\ \mathbf{h}^t &= \text{LayerNorm}(\mathbf{A}^{t} + \text{Transition}(\mathbf{A}^t)) \end{aligned} $$

where $\text{Transition}(.)$ is either a separable convolution or a fully-connected neural network consisting of two position-wise (i.e. applied to each row of $\mathbf{A}^t$ individually) affine transformations with one ReLU in between.

The positional encoding $\mathbf{P}^t$ uses sinusoidal position signal but with an additional time dimension:

$$ \text{PE}(i, t, \delta) = \begin{cases} \sin(\frac{i}{10000^{2\delta'/d}}) \oplus \sin(\frac{t}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta'\\ \cos(\frac{i}{10000^{2\delta'/d}}) \oplus \cos(\frac{t}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta' + 1\\ \end{cases} $$
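A sketch of the position-and-time signal $\mathbf{P}^t$, one row per token position at recurrent step $t$. It interprets $\oplus$ as element-wise addition of the position term and the step term, which is an assumption, and uses 0-indexed dimensions.

```python
import math
from typing import List

def ut_pe(i: int, t: int, d: int, base: float = 10000.0) -> List[float]:
    """Encoding for token position i at recurrent step t (the ⊕ treated as addition)."""
    row = []
    for delta in range(d):
        denom = base ** (2 * (delta // 2) / d)
        if delta % 2 == 0:
            row.append(math.sin(i / denom) + math.sin(t / denom))
        else:
            row.append(math.cos(i / denom) + math.cos(t / denom))
    return row

# P^t for a length-L sequence at step t: the same positions get a different signal at each step.
L, d, t = 4, 8, 2
P_t = [ut_pe(i, t, d) for i in range(1, L + 1)]
print(len(P_t), len(P_t[0]))  # 4 8
```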

A simplified illustration of Universal Transformer. The encoder and decoder share the same basic recurrent structure. But the decoder also attends to final encoder representation $\mathbf{h}^T$. (Image source: Figure 2 in Dehghani, et al. 2019)

In the adaptive version of Universal Transformer, the number of recurrent steps $T$ is dynamically determined by ACT. Each position is equipped with a dynamic ACT halting mechanism. Once a per-token recurrent block halts, it stops taking more recurrent updates but simply copies the current value to the next step until all the blocks halt or until the model reaches a maximum step limit.

Adaptive Modeling

Adaptive modeling refers to a mechanism that can adjust the amount of computation according to different inputs. For example, some tokens may only need local information and thus demand a shorter attention span, or some tokens are relatively easier to predict and do not need to be processed through the entire attention stack.

Adaptive Attention Span#

One key advantage of Transformer is its capability of capturing long-term dependencies. Depending on the context, the model may prefer to attend further back at some times than at others, and one attention head may have a different attention pattern from another. If the attention span could adapt its length flexibly and only attend further back when needed, it would reduce both computation and memory cost, making it possible to support a longer maximum context size in the model.

This is the motivation for Adaptive Attention Span. Sukhbaatar et al. (2019) proposed a self-attention mechanism that seeks an optimal attention span. They hypothesized that different attention heads might assign scores differently within the same context window (See Fig. 14), and thus the optimal span should be trained separately per head.

Two attention heads in the same model, A & B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B looks further back into the past uniformly. (Image source: Sukhbaatar, et al. 2019)

Given the $i$-th token, we need to compute the attention weights between this token and other keys within its attention span of size $s$:

$$ \begin{aligned} e_{ij} &= \mathbf{q}_i {\mathbf{k}_j}^\top \\ a_{ij} &= \text{softmax}(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{r=i-s}^{i-1} \exp(e_{ir})} \\ \mathbf{y}_i &= \sum_{r=i-s}^{i-1}a_{ir}\mathbf{v}_r = \sum_{r=i-s}^{i-1}a_{ir}\mathbf{x}_r\mathbf{W}^v \end{aligned} $$
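A direct transcription of these three equations for a single query is sketched below. It follows the formula literally (keys $r = i-s, \dots, i-1$, no $1/\sqrt{d}$ scaling), assumes the value rows `V[r]` already hold $\mathbf{x}_r\mathbf{W}^v$, and clips the window at the start of the sequence; the function name is hypothetical.

```python
import math

def span_attention(q, K, V, i, s):
    # Attention for query position i over keys r = i-s, ..., i-1 (clipped at 0).
    # Requires i >= 1 so that the window is non-empty.
    lo = max(0, i - s)
    e = [sum(a * b for a, b in zip(q, K[r])) for r in range(lo, i)]  # e_{ir}
    m = max(e)
    w = [math.exp(x - m) for x in e]  # subtract max for numerical stability
    z = sum(w)
    a = [x / z for x in w]  # a_{ir}, normalized only over the span
    d_v = len(V[0])
    y = [sum(a[k] * V[lo + k][j] for k in range(len(a))) for j in range(d_v)]
    return a, y

a, y = span_attention([0.0, 0.0], [[1, 0], [0, 1], [1, 1], [2, 2]],
                      [[1, 0], [2, 0], [4, 0], [8, 0]], i=3, s=2)
```

With a zero query the scores are equal, so the two keys in the span each get weight 0.5 and `y` is the average of `V[1]` and `V[2]`.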

A soft mask function $m_z$ is added to control for an effective adjustable attention span, which maps the distance between query and key into a [0, 1] value. $m_z$ is parameterized by $z \in [0, s]$ and $z$ is to be learned:

$$ m_z(x) = \text{clip}(\frac{1}{R}(R+z-x), 0, 1) $$

where $R$ is a hyper-parameter which defines the softness of $m_z$.

The soft masking function used in the adaptive attention span. (Image source: Sukhbaatar, et al. 2019.)

The soft mask function is applied to the softmax elements in the attention weights:

$$ a_{ij} = \frac{m_z(i-j)\exp(e_{ij})}{\sum_{r=i-s}^{i-1}m_z(i-r) \exp(e_{ir})} $$

In the above equation, $z$ is differentiable so it is trained jointly with other parts of the model. Parameters $z^{(i)}, i=1, \dots, h$ are learned separately per head. Moreover, the loss function has an extra L1 penalty on $\sum_{i=1}^h z^{(i)}$.
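The soft mask, the masked attention weights and the L1 span penalty fit in a few lines. In this sketch $e_{ir}$ denotes the query-key scores, `e[r]` stores them for every key position, and the names (`soft_mask`, `masked_span_weights`, `span_penalty`, `coef`) are hypothetical. The mask equals 1 for distances up to $z$ and then ramps linearly down to 0 over the next $R$ positions.

```python
import math

def soft_mask(x, z, R):
    # m_z(x) = clip((R + z - x) / R, 0, 1) for query-key distance x.
    return min(1.0, max(0.0, (R + z - x) / R))

def masked_span_weights(e, i, s, z, R):
    # a_{ir} for r in [i-s, i-1]; e[r] is the score e_{ir}.
    # With R > 1 the nearest key (distance 1) always has a positive mask,
    # so the denominator cannot be zero.
    rs = range(max(0, i - s), i)
    num = [soft_mask(i - r, z, R) * math.exp(e[r]) for r in rs]
    den = sum(num)
    return [n / den for n in num]

def span_penalty(z_per_head, coef):
    # Extra L1 term on the learned spans z^{(1..h)}, added to the task loss.
    return coef * sum(abs(z) for z in z_per_head)

w = masked_span_weights([0.0] * 6, i=6, s=6, z=2.0, R=2.0)
```

Keys at distance 4 or more are fully masked, distance 3 gets half weight, and distances 1 and 2 are unmasked, so `w` is about `[0, 0, 0, 0.2, 0.4, 0.4]`. Because $m_z$ is piecewise linear in $z$, gradients flow into $z$ through the ramp region.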

Using Adaptive Computation Time, the approach can be further enhanced to have a flexible attention span length, adaptive to the current input dynamically. The span parameter $z_t$ of an attention head at time $t$ is a sigmoidal function, $z_t = S \sigma(\mathbf{v} \cdot \mathbf{x}_t +b)$, where $S$ is the maximum span, and the vector $\mathbf{v}$ and the bias scalar $b$ are learned jointly with other parameters.
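The dynamic span is a one-liner; the sketch below only shows that $z_t$ is a scaled sigmoid and therefore always lies in $(0, S)$. The function name is hypothetical.

```python
import math

def dynamic_span(x_t, v, b, S):
    # z_t = S * sigmoid(v . x_t + b); S is the maximum span.
    logit = sum(vi * xi for vi, xi in zip(v, x_t)) + b
    return S / (1.0 + math.exp(-logit))

z_mid = dynamic_span([1.0, 2.0], [0.0, 0.0], 0.0, 512)  # zero logit -> half of S
```

A zero logit gives exactly half the maximum span, and the learned $\mathbf{v}$ lets the span grow or shrink depending on the current input $\mathbf{x}_t$.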

In the experiments of Transformer with adaptive attention span, Sukhbaatar, et al. (2019) found a general tendency that lower layers do not require very long attention spans, while a few attention heads in higher layers may use exceptionally long spans. Adaptive attention span also helps greatly reduce the number of FLOPs, especially in a big model with many attention layers and a large context length.

Depth-Adaptive Transformer#

At inference time, it is natural to assume that some tokens are easier to predict and thus do not require as much computation as others. Therefore, we may process such tokens through only a limited number of layers to achieve a good balance between speed and performance.

Both Depth-Adaptive Transformer (Elbayad et al. 2020) and Confident Adaptive Language Model (CALM; Schuster et al. 2022) are motivated by this idea and learn to predict the optimal number of layers needed for different input tokens.

Depth-adaptive transformer (Elbayad et al. 2020) attaches an output classifier to every layer to produce exit predictions based on the activations of that layer. The classifier weight matrices can be different per layer or shared across layers. During training, the model samples different sequences of exits so that it is optimized with the hidden states of different layers. The learning objective incorporates the log-likelihoods predicted at different layers, $n=1, \dots, N$:

$$ \text{LL}^n_t = \log p(y_t \vert \mathbf{h}^n_{t-1}) \quad \text{LL}^n = \sum_{t=1}^{\vert\mathbf{y}\vert} \text{LL}^n_t $$
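To make the notation concrete, here is a small sketch that turns per-layer output distributions into the $\text{LL}^n_t$ table and the per-layer totals $\text{LL}^n$. The nested-list layout `probs[n][t]` and the function name are assumptions for illustration; layers are indexed from 0 in code, so `totals[0]` corresponds to $n=1$.

```python
import math

def layer_log_likelihoods(probs, y):
    # probs[n][t] is the distribution p(. | h^n_{t-1}) from layer n's exit classifier.
    # Returns the LL^n_t table and the per-layer sums LL^n.
    ll = [[math.log(p_t[y_t]) for p_t, y_t in zip(layer, y)] for layer in probs]
    return ll, [sum(row) for row in ll]

probs = [[[0.5, 0.5], [0.5, 0.5]],   # a shallow layer: uncertain at both steps
         [[0.9, 0.1], [0.2, 0.8]]]   # a deeper layer: confident and correct
ll, totals = layer_log_likelihoods(probs, y=[0, 1])
```

The deeper layer has the higher total log-likelihood here, which is the kind of signal the oracle below trades off against depth.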

Adaptive depth classifiers output a parametric distribution $q_t$. It is trained with cross entropy loss against an oracle distribution $q^*_t$. The paper explored three configurations for how to learn such a classifier $q_t$.

Illustration of three types of adaptive depth classifiers. (Image source: Elbayad et al. 2020).

Sequence-specific depth classifier: All tokens of the same sequence share the same exit block. It depends on the average of the encoder representation of the sequence. Given an input sequence $\mathbf{x}$ of length $L$, the classifier takes $\bar{\mathbf{x}} = \frac{1}{L} \sum_{t=1}^L \mathbf{x}_t$ as input and outputs a multinomial distribution of $N$ dimensions, corresponding to $N$ layers.

$$ \begin{aligned} q(n \vert \mathbf{x}) &=\text{softmax}(\mathbf{W}_n \bar{\mathbf{x}} + b_n) \in \mathbb{R}^N \\ q_\text{lik}^*(\mathbf{x}, \mathbf{y}) &= \delta(\arg\max_n \text{LL}^n - \lambda n) \\ \text{or }q_\text{corr}^*(\mathbf{x}, \mathbf{y}) &= \delta(\arg\max_n C^n - \lambda n) \text{ where }C^n = \vert\{t \vert y_t = \arg\max_y p(y \vert \mathbf{h}^n_{t-1})\}\vert \\ \end{aligned} $$

where $\delta$ is the Dirac delta (unit impulse) function and $-\lambda n$ is a regularization term to encourage lower-layer exits. The ground truth $q^*$ can be prepared in two ways, based on maximum likelihood $q_\text{lik}^*$ or correctness $q_\text{corr}^*$.
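The sequence-specific classifier, both oracles and the cross-entropy training signal can be sketched as follows. All function names are hypothetical, exit layers are 1-based to match $n$, and since the oracle is a delta distribution, the cross entropy reduces to $-\log q(n^* \vert \mathbf{x})$ at the oracle layer $n^*$.

```python
import math

def seq_depth_distribution(x, W, b):
    # q(n | x) = softmax(W_n x_bar + b_n); x is L x d, W has N rows of length d.
    d = len(x[0])
    x_bar = [sum(row[j] for row in x) / len(x) for j in range(d)]
    logits = [sum(w * xb for w, xb in zip(W_n, x_bar)) + b_n for W_n, b_n in zip(W, b)]
    m = max(logits)
    e = [math.exp(l - m) for l in logits]
    z = sum(e)
    return [v / z for v in e]

def correct_counts(probs, y):
    # C^n: number of steps where layer n's argmax prediction equals y_t.
    return [sum(max(range(len(p)), key=p.__getitem__) == y_t for p, y_t in zip(layer, y))
            for layer in probs]

def oracle_exit(scores, lam):
    # 1-based n maximizing score^n - lam * n (score = LL^n or C^n);
    # the delta oracle puts all of its mass on this layer.
    return max(range(1, len(scores) + 1), key=lambda n: scores[n - 1] - lam * n)

def depth_classifier_loss(q, n_star):
    # Cross entropy between q and a delta oracle at layer n_star.
    return -math.log(q[n_star - 1])

q = seq_depth_distribution([[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0]] * 3, [0.0] * 3)
```

Increasing $\lambda$ pushes the oracle toward shallower exits: `oracle_exit([-5.0, -2.0, -1.9], 0.5)` picks layer 2, while with $\lambda = 0$ it picks layer 3.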

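Token-specific depth classifier (multinomial): Each token is decoded with a different exit block, predicted conditioned on the first decoder hidden state $\mathbf{h}^1_t$:

$$ q_t(n \vert \mathbf{x}, \mathbf{y}_{<t}) = \text{softmax}(\mathbf{W}_n \mathbf{h}^1_t + b_n) \in \mathbb{R}^N $$

Token-specific depth classifier (geometric-like): A binary exit prediction is made per layer per token, $\mathcal{X}^n_t$. The RBF kernel $\kappa(t, t') = \exp(-\frac{\vert t - t' \vert^2}{\sigma})$ is used to smooth the per-layer scores across time steps, so that the oracle incorporates the impact of the current decision on future time steps.

$$ \begin{aligned} \mathcal{X}^n_t &= \text{sigmoid}(\mathbf{w}_n^\top \mathbf{h}^n_t + b_n)\quad \forall n \in [1, \dots, N-1] \\ q_t(n \vert \mathbf{x}, \mathbf{y}_{<t}) &= \begin{cases} \mathcal{X}^n_t \prod_{n' < n} (1 - \mathcal{X}^{n'}_t) & \text{if } n < N \\ \prod_{n' < N} (1 - \mathcal{X}^{n'}_t) & \text{otherwise} \end{cases} \\ q_\text{lik}^*(\mathbf{x}, \mathbf{y}) &= \delta(\arg\max_n \widetilde{\text{LL}}^n_t - \lambda n) \text{ where } \widetilde{\text{LL}}^n_t = \sum_{t'=1}^{\vert\mathbf{y}\vert} \kappa(t, t') \text{LL}^n_{t'} \\ \text{or }q_\text{corr}^*(\mathbf{x}, \mathbf{y}) &= \delta(\arg\max_n \tilde{C}^n_t - \lambda n) \text{ where } C^n_t = \mathbb{1}[y_t = \arg\max_y p(y \vert \mathbf{h}^n_{t-1})],\; \tilde{C}^n_t = \sum_{t'=1}^{\vert\mathbf{y}\vert} \kappa(t, t') C^n_{t'} \end{aligned} $$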
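The geometric-like classifier exits at layer $n$ with probability $\mathcal{X}^n_t$ times the probability of not having exited at any earlier layer, and the last layer absorbs whatever mass is left. The sketch below computes that distribution and the RBF smoothing of per-token scores, using the standard RBF form with a negative exponent. Function names and list layouts are assumptions for illustration.

```python
import math

def geometric_exit_distribution(chi):
    # chi[k] is the binary exit probability X^{k+1}_t for layers 1..N-1.
    # Returns q_t over N layers; the final layer receives the remaining mass.
    q, survive = [], 1.0
    for c in chi:
        q.append(c * survive)   # exit here, having not exited earlier
        survive *= 1.0 - c      # probability of continuing past this layer
    q.append(survive)
    return q

def rbf_smooth(scores, sigma):
    # Smoothed score at each t: sum over t' of kappa(t, t') * scores[t'],
    # with kappa(t, t') = exp(-|t - t'|^2 / sigma).
    T = len(scores)
    return [sum(math.exp(-((t - tp) ** 2) / sigma) * scores[tp] for tp in range(T))
            for t in range(T)]

q_t = geometric_exit_distribution([0.5, 0.5])   # N = 3 layers
smoothed = rbf_smooth([1.0, 0.0, 0.0], sigma=1.0)
```

The distribution always sums to 1 by construction. The smoothing spreads a score at one step to its neighbors with weights $e^{-1}, e^{-4}, \dots$, which is how a per-token oracle can account for nearby time steps; smaller $\sigma$ makes the smoothing more local.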

At inference time, the confidence threshold for making an exit decision needs to be calibrated. Depth-adaptive transformer finds such a threshold on a validation set via grid search. CALM (Schuster et al. 2022) applied the Learn then Test (LTT) framework (Angelopoulos et al. 2021) to identify a subset of valid thresholds and chose the minimum value as the threshold for inference. Besides training a per-layer exit classifier, CALM also explored other confidence scores for exit decisions, including the softmax response (i.e. the difference between the top two softmax outputs) and hidden-state saturation (i.e. $\cos(\mathbf{h}^n_t, \mathbf{h}^{n+1}_t)$). They found that softmax responses result in the best inference speedup.
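The two training-free confidence scores and a threshold-based exit rule are sketched below. The threshold is simply passed in; calibrating it (by grid search or LTT) is not shown, and the function names are hypothetical.

```python
import math

def softmax_response(probs):
    # Difference between the top-two softmax probabilities.
    top1, top2 = sorted(probs, reverse=True)[:2]
    return top1 - top2

def state_saturation(h_n, h_next):
    # Cosine similarity between hidden states of consecutive layers n and n+1.
    dot = sum(a * b for a, b in zip(h_n, h_next))
    return dot / (math.sqrt(sum(a * a for a in h_n)) * math.sqrt(sum(b * b for b in h_next)))

def early_exit_layer(per_layer_probs, threshold):
    # Exit at the first layer (1-based) whose softmax response clears the threshold;
    # otherwise fall through to the last layer.
    for n, probs in enumerate(per_layer_probs, start=1):
        if softmax_response(probs) >= threshold:
            return n
    return len(per_layer_probs)

exit_at = early_exit_layer([[0.4, 0.35, 0.25], [0.8, 0.1, 0.1], [0.9, 0.05, 0.05]], 0.5)
```

Layer 1 is too uncertain (margin 0.05), so this token exits at layer 2 (margin 0.7) and skips the remaining layers. A higher threshold trades speed for fidelity to the full model.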

Efficient Attention#

The computation and memory cost of the vanilla Transformer grows quadratically with sequence length, which makes it hard to apply to very long sequences. Many efficiency improvements for the Transformer architecture target the self-attention module, making it cheaper, smaller or faster to run. See the survey paper on Efficient Transformers (Tay et al. 2020).

Sparse Attention Patterns#

Fixed Local Context#

A simple way to make self-attention less expensive is to restrict the attention span of each token to its local context only, so that the cost of self-attention grows linearly with the sequence length.
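Counting the allowed query-key pairs shows where the linear cost comes from. This sketch assumes a window of `w` tokens on each side (or only previous tokens when `causal=True`) plus the token itself; the function name is hypothetical.

```python
def local_attention_pairs(L, w, causal=True):
    # Allowed (query, key) pairs when each query only sees keys within distance w.
    pairs = []
    for i in range(L):
        lo = max(0, i - w)
        hi = i if causal else min(L - 1, i + w)
        pairs.extend((i, j) for j in range(lo, hi + 1))
    return pairs

n_local = len(local_attention_pairs(1000, 4))   # at most 5 keys per query
n_full = 1000 * 1001 // 2                       # full causal attention
```

With $L = 1000$ and a window of 4, there are 4,990 scored pairs versus 500,500 for full causal attention; doubling $L$ roughly doubles the former but quadruples the latter.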

The idea was introduced by Image Transformer (Parmar, et al. 2018), which formulates image generation as sequence modeling using an encoder-decoder transformer architecture:

The encoder generates a contextualized, per-pixel-channel representation of the source image;

Then the decoder autoregressively generates an output image, one channel per pixel at each time step.

Let’s label