Image

paper, blog, code

TL;DR

  • I read this because : first paper on linear transformers (read in April)
  • task : autoregressive sequence modeling, language modeling, machine translation
  • problem : self-attention compares all pairs of tokens, time/memory O(N²), inefficient for long sequences
  • idea : replace softmax attention with the kernel form $\phi(Q)\phi(K)^T$ and exploit the associativity of matrix multiplication, which allows a cumulative-sum formulation
  • input/output : token -> token
  • architecture : softmax attention replaced by a kernelized form with an elu-based feature map; no other architectural changes
  • objective : CE loss
  • baseline : Transformer, Reformer
  • data : WMT, language modeling benchmark
  • evaluation : BLEU (MT), perplexity (LM)
  • result : large speed/memory improvement on long sequences, with similar or slightly lower accuracy
  • contribution: reinterpreted attention as a kernel, proposed O(N) linear attention, showed that transformer behaves like an RNN
  • etc. : causal masking is naturally embedded in the prefix-sum structure, enabling layerwise parallelism

Details

  • conversation with chatGPT: link

3.1 Transformer

Image
  • $f_l(\cdot)$ is just the FFN
  • $A_l(\cdot)$ is self-attention
Image

We can express the softmax term there as a generic similarity function $\mathrm{sim}(\cdot)$

Image

3.2. Linearized attention

This is where I got confused: we’re going to use the kernel trick. The only constraint on attention is that $\mathrm{sim}(\cdot)$ must be non-negative, so any kernel $k(x,y) : \mathbb{R}^{2 \times F} \to \mathbb{R}_{+}$ can serve as the similarity function.

Suppose we have such a kernel $k$ with feature representation $\phi(x)$, i.e. $k(x,y) = \phi(x)^T \phi(y)$. Rewriting eq (2) with it, we get

Image

Above, $\phi(Q_i)^T$ does not depend on $j$, so it can be pulled out of $\sum_j$, which gives us the expression Image

In this case, the feature map $\phi(\cdot)$ is applied row-wise to the $Q$ and $K$ matrices. The parenthesized term in eq (6) is $\phi(K)^T V$, with $\phi(K)^T \in \mathbb{R}^{D\times N}$ and $V \in \mathbb{R}^{N\times M}$; it is computed once and reused for every query, resulting in time and space complexity of $O(N)$. (I was confused about the time and space complexity at first – the reason is that the $\phi(K)^T V$ product and the normalizer $\sum_j \phi(K_j)$ are stored and reused.) Image
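A minimal NumPy sketch of this non-causal linearized attention (function names are mine, not the authors' code): the $\phi(K)^T V$ term and the normalizer are each computed once, so the cost is linear in $N$.

```python
import numpy as np

def elu_feature_map(x):
    # phi(x) = elu(x) + 1, the paper's choice of positive feature map
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V, eps=1e-6):
    """Non-causal linearized attention, O(N) in the sequence length."""
    phi_Q = elu_feature_map(Q)      # (N, D)
    phi_K = elu_feature_map(K)      # (N, D)
    KV = phi_K.T @ V                # (D, M) -- computed once, reused for every query
    Z = phi_K.sum(axis=0)           # (D,)   -- normalizer, also computed once
    num = phi_Q @ KV                # (N, M)
    den = phi_Q @ Z + eps           # (N,)
    return num / den[:, None]
```

Computing `phi_Q @ (phi_K.T @ V)` instead of `(phi_Q @ phi_K.T) @ V` is exactly the associativity rearrangement the paper relies on: the result is identical, but the $N \times N$ similarity matrix is never materialized.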

Feature maps and computational cost

The computational cost depends on which kernel you choose, so the authors pick the cheap feature map $\phi(x) = \mathrm{elu}(x) + 1$ Image

They chose elu over relu because they wanted the gradient to flow even for inputs below zero.
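As a sanity check on this choice, a small NumPy sketch (names are mine): $\mathrm{elu}(x)+1$ is strictly positive, so the similarity is non-negative, and its gradient never vanishes, unlike relu's for negative inputs.

```python
import numpy as np

def elu(x, alpha=1.0):
    # exponential linear unit
    return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))

def feature_map(x):
    # phi(x) = elu(x) + 1: strictly positive, so phi(q) . phi(k) >= 0
    return elu(x) + 1.0

def feature_map_grad(x):
    # derivative of phi: 1 for x > 0, exp(x) for x <= 0 -- never exactly zero,
    # whereas relu's gradient is 0 for every negative input
    return np.where(x > 0, 1.0, np.exp(x))
```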

3.3 Causal Masking

How do we get the Transformer’s causal masking here? By restricting the summation to $j \le i$ instead of summing over all $j$

(previous expression)

Image

(w/ causal masking) Image

We can calculate $S_i$ from $S_{i-1}$ because it is a cumulative sum. I was confused when I first read this, but the idea is: inference uses the cumulative sums (the recurrent form), while training applies causal masks in parallel like the original transformer.
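A NumPy sketch of the causal case as a running sum (my naming; the paper's actual implementation is a fused CUDA kernel): $S_i$ and the normalizer $z_i$ are updated in place as $i$ advances.

```python
import numpy as np

def causal_linear_attention(Q, K, V, eps=1e-6):
    """Causal linear attention via cumulative (prefix) sums over positions."""
    N, D = Q.shape
    M = V.shape[1]
    phi_Q = np.where(Q > 0, Q + 1.0, np.exp(Q))   # elu(x) + 1
    phi_K = np.where(K > 0, K + 1.0, np.exp(K))
    S = np.zeros((D, M))   # S_i = sum_{j <= i} phi(K_j) V_j^T
    z = np.zeros(D)        # z_i = sum_{j <= i} phi(K_j)
    out = np.zeros((N, M))
    for i in range(N):
        S += np.outer(phi_K[i], V[i])             # S_i from S_{i-1}
        z += phi_K[i]
        out[i] = (phi_Q[i] @ S) / (phi_Q[i] @ z + eps)
    return out
```

Each output position only ever sees keys and values up to its own index, so no explicit mask matrix is needed.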

3.3.1 Gradient Computation

Solving for the gradient naively gives another $O(N^2)$ term, but the paper derives a form that makes the backward pass linear as well

Image

3.3.2 Training and Inference

The good thing about the linear transformer is that inference does not need to keep the past $K$, $V$ pairs, so memory doesn’t grow with sequence length. In other words, it takes the best of both training (parallelism) and inference (constant memory) Image
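The RNN view can be sketched as a constant-size state update per token (hypothetical helper names, assuming the $\mathrm{elu}(x)+1$ feature map): the state $(S, z)$ has a fixed size regardless of how many tokens have been processed.

```python
import numpy as np

def feature(x):
    # phi(x) = elu(x) + 1
    return np.where(x > 0, x + 1.0, np.exp(x))

def init_state(d, m):
    # S accumulates phi(k_j) v_j^T, z accumulates phi(k_j);
    # both sizes are independent of the number of tokens seen so far
    return np.zeros((d, m)), np.zeros(d)

def attention_step(q, k, v, state, eps=1e-6):
    """One RNN-style decoding step of causal linear attention."""
    S, z = state
    S = S + np.outer(feature(k), v)
    z = z + feature(k)
    out = (feature(q) @ S) / (feature(q) @ z + eps)
    return out, (S, z)
```

Calling `attention_step` token by token reproduces causal linear attention exactly, without ever storing the past keys and values.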

3.4. Transformers are RNNs

Image

Experiment

Skipped