TL;DR
- I read this because.. : 4์์ linear transformer – ์ฒซ๋ฒ์งธ ๋ ผ๋ฌธ
- task : autoregressive sequence modeling, language modeling, machine translation
- problem : self-attention์ด ๋ชจ๋ ํ ํฐ ์์ ๋น๊ตํด์ ์๊ฐ/๋ฉ๋ชจ๋ฆฌ O(Nยฒ), ๊ธด ์ํ์ค์์ ๋นํจ์จ์
- idea : softmax attention์ kernel ํํ $\phi(Q)\phi(K)^T$๋ก ๋ฐ๊ฟ์ ๊ฒฐํฉ๋ฒ์น์ผ๋ก ์ฌ๋ฐฐ์ด, ์ด๋ฅผ ํตํด cumulative sum์ ํจ
- input/output : token -> token
- architecture : softmax ๋์ kernel with elu function์ผ๋ก ๋ฐ๊ฟ. ๊ทธ ์ธ ์ํคํ ์ณ ์ ๋ณ๊ฒฝ์ ์ ์์
- objective : CE loss
- baseline : Transformer, RoFormer
- data : WMT, language modeling benchmark
- evaluation : BLEU (MT), perplexity (LM)
- result : ๊ธด ์ํ์ค์์ ํฐ ์๋/๋ฉ๋ชจ๋ฆฌ ๊ฐ์ , ์ฑ๋ฅ์ ์ฝ๊ฐ ๊ฐ์ํ๊ฑฐ๋ ์ ์ฌํ ์์ค
- contribution : attention์ kernel๋ก ์ฌํด์, O(N) linear attention ์ ์, transformer๊ฐ RNN์ฒ๋ผ ๋์ํจ์ ๋ณด์
- etc. : causal masking์ด prefix sum ๊ตฌ์กฐ์ ์์ฐ์ค๋ฝ๊ฒ ํฌํจ๋จ, layerwise parallelism ๊ฐ๋ฅ
Details
- conversation with chatGPT: link
- X$
3.1 Transformer
- $f_l(.)$์ ๊ทธ๋ฅ FFN
- $A_l(.)$ self-attnetion
์ ๊ธฐ์ softmax term์ ๊ทธ๋ฅ ์ ์ฌ๋ ํจ์ $sim(\cdot)$๋ก ํํํ ์ ์์
3.2. Linearized attention
์ฌ๊ธฐ๊ฐ ๊ฐ์๊ธฐ ํท๊ฐ๋ฆฌ๋๋ฐ, Kernel Trick์ด๋ ๊ฑธ ์ธ๊ฑฐ์. attention์์ $sim(\cdot)$์ “non-negative"์ฌ์ผ ํ๋ค๋ ์ ์ฝ ๋ฐ์ ์์ ๊ทธ๋ ๋ค๋ฉด ๋ชจ๋ kernel ์ค์ , $k(x,y) : \mathbb{R}^{2 \times F} -> \mathbb{R}_{+}$ ๋ฅผ ํฌํจํ ์ ์๊ฒ ๋จ
๊ทธ๋ฐ “imaginery kernel”($k$)์ด ์๋ค ์น๊ณ , feature ํํ $\phi(x)$์ ๋ํด eq (2)๋ฅผ ๋ค์ ์ฐ๋ฉด
์์์ $\sum _j$๋ j์ ๋ํ ๊ฐ์ด๊ธฐ ๋๋ฌธ์ $\phi(Q_i)^T$๋ฅผ ๋๊ธธ ์ ์๊ณ , ๊ทธ๋ฌ๋ฉด ์๋์ ๊ฐ์ด ์์ด ๋จ
์ด ๋ feature map $\phi(\cdot)$์ $Q$, $K$ ํ๋ ฌ์ row-wise๋ก ์ฐ์ฐ๋จ
eq (6)์ ๊ดํธ ์์ $\phi(X)^T\in \mathbb{R}^{D\times N}$, $\phi(X)^T\in \mathbb{R}^{N\times D}$ ์ด์ด์ $O(N)$์ ์๊ฐ, ๊ณต๊ฐ ๋ณต์ก๋๋ฅผ ๊ฐ์ง๊ฒ ๋จ. (๊ณต๊ฐ ๋ณต์ก๋์ ์๊ฐ ๋ณต์ก๋๊ฐ ํท๊ฐ๋ฆฌ๋ค..) –
๊ทธ ์ด์ ๋ ์ฐ๋ฆฌ๊ฐ KV, K๋ฅผ ํ๋ฒ ์ ์ฅํ๊ณ ์ฌ์ฌ์ฉํ ๊ฒ์ด๊ธฐ ๋๋ฌธ.
Feature maps and computational cost
Kernel์ ์ด๋ค ๊ฒ์ ์ฌ์ฉํ๋์ ๋ฐ๋ผ computational cost๊ฐ ๋ฌ๋ผ์ง๊ธฐ ๋๋ฌธ์ elu ํจ์๋ฅผ ์ ํ
relu over elu๋ฅผ ์ฌ์ฉํ ๊ฒ์ 0 ์ดํ์ผ ๋๋ gradient๊ฐ ํ๋ ์ผ๋ฉด ์ข๊ฒ ์ด์
3.3 Causal Masking
Transformer์ Causal masking์ ์ฌ๊ธฐ์ ์ด๋ป๊ฒ ๊ตฌํ ์ ์๋ ์ด๊ฒ์ summation์ ๋ชจ๋ j์ ๋ํด ํ๋๊ฒ ์๋๋ผ $i$๊น์ง ํ๋๋ก ๋ฐ๊พธ๋ฉด ๋จ
(์ด์ ์ ์)
(w/ causal masking)
์ฐ๋ฆฌ๋ $S_{i-1}$๋ก ๋ถํฐ $S_{i}$๋ฅผ ๊ณ์ฐํ ์ ์์. ์๋ํ๋ฉด ๋์ ํฉ์ด๊ธฐ ๋๋ฌธ์. ์ฌ๊ธฐ์ ์ฒ์ ์ฝ์ ๋ ํท๊ฐ๋ ธ๋๋ฐ, Inference ์์ ๋์ ํฉ์ผ๋ก ํ๋ค๋ ๊ฒ์ด๊ณ ์ค์ ํ์ต ๋๋ ์๋์ transformer์ฒ๋ผ causal mask๋ฅผ ์ ์ฉ.
3.3.1 Gradient Computation
gradient๋ฅผ ๋์ด๋ธํ๊ฒ ๊ตฌํ๋ฉด ๋ $O(N^2)$ ๋ณต์ก๋๊ฐ ๋์ง๋ง ์ ๊ตฌํด์ ์๋ Linearํ๊ฒ ํจ
3.3.2 Training and Inference
Transformer ๋๋น ์ข์ ์ ์ Inference ์์ QK๋ฅผ ์๊ฐ์ง๊ณ ์์ด๋ ๋์ด์ ๋ฉ๋ชจ๋ฆฌ๊ฐ seq len์ ๋น๋กํ์ฌ ๋์ด๋์ง ์์. ์ฆ train, inference์ ์ข์ ์ ์ ๋ค ๊ฐ์ ธ์ด
3.4. Transformers are RNNs
Experiment
์คํต ใ