TL;DR
- I read this because: April's reading theme is linear transformers, and this is the first paper.
- task: autoregressive sequence modeling (autoregressive image generation, speech recognition)
- problem: self-attention compares all pairs of tokens, so time/memory cost is O(N²), which is inefficient for long sequences
- idea: replace softmax attention with the kernel form $\phi(Q)\phi(K)^T$ and rearrange it using the associative law of matrix multiplication, which allows a cumulative-sum formulation
- input/output: token -> token
- architecture: replaced softmax with a kernel feature map based on the elu function ($\phi(x) = \mathrm{elu}(x) + 1$); no other architectural changes
- objective: cross-entropy (CE) loss
- baseline: Transformer (softmax), Reformer (LSH attention)
- data: MNIST, CIFAR-10 (image generation), WSJ (speech recognition)
- evaluation: bits/dim (image generation), phoneme error rate (speech recognition)
- result: large speed/memory improvements on long sequences, with similar or slightly lower quality
- contribution: reinterpreted attention as a kernel, proposed O(N) linear attention, showed that an autoregressive transformer behaves like an RNN
- etc.: causal masking is naturally embedded in the prefix-sum structure, so training stays parallel across positions
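
The whole trick in one line (my summary of the idea bullet above; the parenthesization is what changes the complexity):

$$\underbrace{\big(\phi(Q)\,\phi(K)^\top\big)\,V}_{O(N^2)\ \text{pairwise matrix}} \;=\; \phi(Q)\,\underbrace{\big(\phi(K)^\top V\big)}_{O(N)\ \text{to compute, size } D \times M}$$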
Details
- conversation with ChatGPT: link
3.1 Transformer
- $f_l(\cdot)$ is just the FFN
- $A_l(\cdot)$ is self-attention
The softmax term there could just as well be expressed as a generic similarity function $\mathrm{sim}(\cdot)$.
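
For reference, this generalized attention written row-wise (this is what the note below calls eq (2)):

$$V'_i = \frac{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)\, V_j}{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)}$$

Choosing $\mathrm{sim}(q,k) = \exp(q^\top k / \sqrt{D})$ recovers ordinary softmax attention.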
3.2 Linearized attention
This is where I got confused at first: the paper uses a kernel trick. The only constraint attention places on $\mathrm{sim}(\cdot)$ is that it must be non-negative, and that admits every kernel $k(x,y) : \mathbb{R}^{2 \times F} \to \mathbb{R}_{+}$.
Suppose we have such a kernel $k$ with feature representation $\phi(x)$, i.e. $k(x,y) = \phi(x)^\top \phi(y)$. Rewriting eq (2) in terms of $\phi$, we get

$$V'_i = \frac{\sum_{j=1}^{N} \phi(Q_i)^\top \phi(K_j)\, V_j}{\sum_{j=1}^{N} \phi(Q_i)^\top \phi(K_j)}$$
Above, the sums over $j$ do not involve $\phi(Q_i)^\top$, so it can be factored out of both sums, which gives

$$V'_i = \frac{\phi(Q_i)^\top \sum_{j=1}^{N} \phi(K_j)\, V_j^\top}{\phi(Q_i)^\top \sum_{j=1}^{N} \phi(K_j)}$$
In this case, the feature map $\phi(\cdot)$ is computed row-wise on the $Q$, $K$ matrices
The term in parentheses in eq (6), $\phi(K)^\top V$, multiplies $\phi(K)^\top \in \mathbb{R}^{D \times N}$ by $V \in \mathbb{R}^{N \times M}$ in a single pass over the sequence, so both time and space complexity are $O(N)$.
The reason is that $\phi(K)^\top V$ (and the normalizer $\sum_j \phi(K_j)$) is computed once, stored, and reused for every query.
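
A minimal NumPy sketch of the non-causal case, using the paper's $\phi(x) = \mathrm{elu}(x) + 1$ feature map (the function names and the `eps` stabilizer are my own, not the paper's implementation):

```python
import numpy as np

def phi(x):
    # elu(x) + 1: equals exp(x) for x <= 0 and x + 1 for x > 0, so always positive
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention(Q, K, V, eps=1e-6):
    """Non-causal linear attention. Q, K: (N, D); V: (N, M).
    phi(K)^T V and sum_j phi(K_j) are computed once and reused for every query."""
    pQ, pK = phi(Q), phi(K)
    KV = pK.T @ V                               # (D, M): one O(N) pass over the sequence
    Z = pK.sum(axis=0)                          # (D,): shared normalizer term
    return (pQ @ KV) / (pQ @ Z + eps)[:, None]  # (N, M)
```

Note that `KV` and `Z` have shapes independent of $N$, which is exactly why memory stays flat.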
Feature maps and computational cost
The computational cost depends on which kernel you choose; the paper picks a feature map based on the elu function, $\phi(x) = \mathrm{elu}(x) + 1$.
They chose elu over relu because they wanted the gradient to flow even when the input is below zero.
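
Concretely, for $x < 0$:

$$\frac{d}{dx}\,\mathrm{relu}(x) = 0, \qquad \frac{d}{dx}\big(\mathrm{elu}(x) + 1\big) = e^{x} > 0,$$

so the elu-based feature map never zeroes out the gradient.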
3.3 Causal Masking
How do we get the Transformer's causal masking here? By restricting the summation to $j \le i$ instead of summing over all $j$:
(previous expression)

$$V'_i = \frac{\phi(Q_i)^\top \sum_{j=1}^{N} \phi(K_j)\, V_j^\top}{\phi(Q_i)^\top \sum_{j=1}^{N} \phi(K_j)}$$

(w/ causal masking)

$$V'_i = \frac{\phi(Q_i)^\top \sum_{j=1}^{i} \phi(K_j)\, V_j^\top}{\phi(Q_i)^\top \sum_{j=1}^{i} \phi(K_j)} = \frac{\phi(Q_i)^\top S_i}{\phi(Q_i)^\top Z_i}, \quad \text{where } S_i = \sum_{j=1}^{i} \phi(K_j)\, V_j^\top,\ \ Z_i = \sum_{j=1}^{i} \phi(K_j)$$
We can compute $S_i$ from $S_{i-1}$ because it is a cumulative sum. I was confused when I first read this, but it seems the cumulative sums are used for inference, while actual training applies a causal mask like the original transformer.
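
A sketch of the prefix-sum view in NumPy (my own toy code; note that materializing every $S_i$ costs $O(N \cdot D \cdot M)$ memory, which is exactly the issue §3.3.1 addresses):

```python
import numpy as np

phi = lambda x: np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))  # elu(x) + 1

def causal_linear_attention(Q, K, V, eps=1e-6):
    """Causal linear attention via prefix sums. Q, K: (N, D); V: (N, M).
    The causal mask is just 'sum up to i' -- a cumulative sum over the sequence."""
    pQ, pK = phi(Q), phi(K)
    S = np.cumsum(pK[:, :, None] * V[:, None, :], axis=0)  # (N, D, M): S_i = sum_{j<=i} phi(K_j) V_j^T
    Z = np.cumsum(pK, axis=0)                               # (N, D):    Z_i = sum_{j<=i} phi(K_j)
    num = np.einsum('nd,ndm->nm', pQ, S)                    # phi(Q_i)^T S_i for every i
    den = np.einsum('nd,nd->n', pQ, Z) + eps                # phi(Q_i)^T Z_i
    return num / den[:, None]
```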
3.3.1 Gradient Computation
If we solve for the gradient naively, we get another $O(N^2)$-scale blow-up, but the paper derives it cleanly so that the backward pass is linear as well.
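
A sketch of the structure (I re-derived these for the unnormalized numerator $\bar{V}_i = S_i^\top \phi(Q_i)$; see the paper for the exact equations):

$$\nabla_{\phi(Q_i)} \mathcal{L} = \Big(\sum_{j=1}^{i} \phi(K_j)\, V_j^\top\Big)\, \nabla_{\bar{V}_i} \mathcal{L}, \qquad \nabla_{\phi(K_i)} \mathcal{L} = \Big(\sum_{j=i}^{N} \phi(Q_j)\, \big(\nabla_{\bar{V}_j} \mathcal{L}\big)^\top\Big)\, V_i$$

The first is a forward cumulative sum over $j \le i$, the second a reverse cumulative sum over $j \ge i$; both run in $O(N)$ without storing every intermediate $S_i$.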
3.3.2 Training and Inference
The good thing here is that you don't need to keep the past keys/values around for inference, so memory doesn't grow with sequence length. In other words, it takes the good parts of both training (parallel, like a transformer) and inference (constant memory, like an RNN).
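
A minimal sketch of what a constant-memory decoding step looks like (the `step` helper and its signature are hypothetical, not the paper's API; the state is just the pair $(S, Z)$):

```python
import numpy as np

phi = lambda x: np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))  # elu(x) + 1

def step(state, q_i, k_i, v_i, eps=1e-6):
    """One decoding step. state = (S, Z) is all we carry between tokens,
    so memory is O(D*M) no matter how long the sequence gets."""
    S, Z = state
    S = S + np.outer(phi(k_i), v_i)              # S_i = S_{i-1} + phi(K_i) V_i^T
    Z = Z + phi(k_i)                             # Z_i = Z_{i-1} + phi(K_i)
    out = (phi(q_i) @ S) / (phi(q_i) @ Z + eps)  # (M,)
    return (S, Z), out

# toy usage: decode 10 tokens with a fixed-size state
D, M = 4, 4
state = (np.zeros((D, M)), np.zeros(D))
for t in range(10):
    q = k = np.random.randn(D); v = np.random.randn(M)
    state, out = step(state, q, k, v)
```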
3.4 Transformers are RNNs
Experiment
Skipped.