
paper, blog, code

TL;DR

  • I read this because : April's theme is linear transformers – this is the first paper
  • task : autoregressive sequence modeling, language modeling, machine translation
  • problem : self-attention compares every pair of tokens, so its time/memory cost is O(N²) – inefficient on long sequences
  • idea : rewrite softmax attention in kernel form $\phi(Q)\phi(K)^T$ and reorder the products by associativity, which turns the computation into a cumulative sum (see the sketch right after this list)
  • input/output : token -> token
  • architecture : replace softmax with a kernel built on the elu function; no other architectural changes
  • objective : CE loss
  • baseline : Transformer, Reformer
  • data : WMT, language modeling benchmark
  • evaluation : BLEU (MT), perplexity (LM)
  • result : ๊ธด ์‹œํ€€์Šค์—์„œ ํฐ ์†๋„/๋ฉ”๋ชจ๋ฆฌ ๊ฐœ์„ , ์„ฑ๋Šฅ์€ ์•ฝ๊ฐ„ ๊ฐ์†Œํ•˜๊ฑฐ๋‚˜ ์œ ์‚ฌํ•œ ์ˆ˜์ค€
  • contribution : attention์„ kernel๋กœ ์žฌํ•ด์„, O(N) linear attention ์ œ์•ˆ, transformer๊ฐ€ RNN์ฒ˜๋Ÿผ ๋™์ž‘ํ•จ์„ ๋ณด์ž„
  • etc. : causal masking์ด prefix sum ๊ตฌ์กฐ์— ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ํฌํ•จ๋จ, layerwise parallelism ๊ฐ€๋Šฅ

Details

  • conversation with chatGPT: link

3.1 Transformer

$$T_l(x) = f_l(A_l(x) + x) \qquad (1)$$
  • $f_l(.)$์€ ๊ทธ๋ƒฅ FFN
  • $A_l(.)$ self-attnetion
$$A_l(x) = V' = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{D}}\right)V \qquad (2)$$

์ €๊ธฐ์„œ softmax term์„ ๊ทธ๋ƒฅ ์œ ์‚ฌ๋„ ํ•จ์ˆ˜ $sim(\cdot)$๋กœ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ์Œ

$$V'_i = \frac{\sum_{j=1}^{N}\mathrm{sim}(Q_i, K_j)\,V_j}{\sum_{j=1}^{N}\mathrm{sim}(Q_i, K_j)} \qquad (3)$$

3.2. Linearized attention

์—ฌ๊ธฐ๊ฐ€ ๊ฐ‘์ž๊ธฐ ํ—ท๊ฐˆ๋ฆฌ๋Š”๋ฐ, Kernel Trick์ด๋ž€ ๊ฑธ ์“ธ๊ฑฐ์ž„. attention์—์„œ $sim(\cdot)$์€ “non-negative"์—ฌ์•ผ ํ•œ๋‹ค๋Š” ์ œ์•ฝ ๋ฐ–์— ์—†์Œ ๊ทธ๋ ‡๋‹ค๋ฉด ๋ชจ๋“  kernel ์ค‘์— , $k(x,y) : \mathbb{R}^{2 \times F} -> \mathbb{R}_{+}$ ๋ฅผ ํฌํ•จํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋จ

๊ทธ๋Ÿฐ “imaginery kernel”($k$)์ด ์žˆ๋‹ค ์น˜๊ณ , feature ํ‘œํ˜„ $\phi(x)$์— ๋Œ€ํ•ด eq (2)๋ฅผ ๋‹ค์‹œ ์“ฐ๋ฉด

$$V'_i = \frac{\sum_{j=1}^{N} \phi(Q_i)^T \phi(K_j)\, V_j}{\sum_{j=1}^{N} \phi(Q_i)^T \phi(K_j)} \qquad (4)$$

์œ„์—์„œ $\sum _j$๋Š” j์— ๋Œ€ํ•œ ๊ฐ’์ด๊ธฐ ๋•Œ๋ฌธ์— $\phi(Q_i)^T$๋ฅผ ๋„˜๊ธธ ์ˆ˜ ์žˆ๊ณ , ๊ทธ๋Ÿฌ๋ฉด ์•„๋ž˜์™€ ๊ฐ™์ด ์‹์ด ๋จ Image

์ด ๋•Œ feature map $\phi(\cdot)$์€ $Q$, $K$ ํ–‰๋ ฌ์— row-wise๋กœ ์—ฐ์‚ฐ๋จ eq (6)์˜ ๊ด„ํ˜ธ ์•ˆ์€ $\phi(X)^T\in \mathbb{R}^{D\times N}$, $\phi(X)^T\in \mathbb{R}^{N\times D}$ ์ด์–ด์„œ $O(N)$์˜ ์‹œ๊ฐ„, ๊ณต๊ฐ„ ๋ณต์žก๋„๋ฅผ ๊ฐ€์ง€๊ฒŒ ๋จ. (๊ณต๊ฐ„ ๋ณต์žก๋„์™€ ์‹œ๊ฐ„ ๋ณต์žก๋„๊ฐ€ ํ—ท๊ฐˆ๋ฆฌ๋„ค..) – ๊ทธ ์ด์œ ๋Š” ์šฐ๋ฆฌ๊ฐ€ KV, K๋ฅผ ํ•œ๋ฒˆ ์ €์žฅํ•˜๊ณ  ์žฌ์‚ฌ์šฉํ•  ๊ฒƒ์ด๊ธฐ ๋•Œ๋ฌธ. Image

Feature maps and computational cost

Kernel์„ ์–ด๋–ค ๊ฒƒ์„ ์‚ฌ์šฉํ•˜๋ƒ์— ๋”ฐ๋ผ computational cost๊ฐ€ ๋‹ฌ๋ผ์ง€๊ธฐ ๋•Œ๋ฌธ์— elu ํ•จ์ˆ˜๋ฅผ ์„ ํƒ Image

relu over elu๋ฅผ ์‚ฌ์šฉํ•œ ๊ฒƒ์€ 0 ์ดํ•˜์ผ ๋•Œ๋„ gradient๊ฐ€ ํ˜๋ €์œผ๋ฉด ์ข‹๊ฒ ์–ด์„œ

3.3 Causal Masking

Transformer์˜ Causal masking์„ ์—ฌ๊ธฐ์„  ์–ด๋–ป๊ฒŒ ๊ตฌํ•  ์ˆ˜ ์žˆ๋ƒ ์ด๊ฒƒ์€ summation์„ ๋ชจ๋“  j์— ๋Œ€ํ•ด ํ•˜๋Š”๊ฒŒ ์•„๋‹ˆ๋ผ $i$๊นŒ์ง€ ํ•˜๋„๋ก ๋ฐ”๊พธ๋ฉด ๋จ

(์ด์ „์˜ ์‹)

Image

(w/ causal masking) Image

์šฐ๋ฆฌ๋Š” $S_{i-1}$๋กœ ๋ถ€ํ„ฐ $S_{i}$๋ฅผ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ์Œ. ์™œ๋ƒํ•˜๋ฉด ๋ˆ„์ ํ•ฉ์ด๊ธฐ ๋•Œ๋ฌธ์—. ์—ฌ๊ธฐ์„œ ์ฒ˜์Œ ์ฝ์„ ๋•Œ ํ—ท๊ฐˆ๋ ธ๋Š”๋ฐ, Inference ์‹œ์— ๋ˆ„์ ํ•ฉ์œผ๋กœ ํ•œ๋‹ค๋Š” ๊ฒƒ์ด๊ณ  ์‹ค์ œ ํ•™์Šต ๋•Œ๋Š” ์›๋ž˜์˜ transformer์ฒ˜๋Ÿผ causal mask๋ฅผ ์ ์šฉ.

3.3.1 Gradient Computation

gradient๋ฅผ ๋‚˜์ด๋ธŒํ•˜๊ฒŒ ๊ตฌํ•˜๋ฉด ๋˜ $O(N^2)$ ๋ณต์žก๋„๊ฐ€ ๋˜์ง€๋งŒ ์ž˜ ๊ตฌํ•ด์„œ ์–˜๋„ Linearํ•˜๊ฒŒ ํ•จ

*(equations in the paper: the gradients of eq (8) with respect to $\phi(Q)$, $\phi(K)$, and $V$, each written as a cumulative sum – forward for $\phi(Q)$, reversed for $\phi(K)$ and $V$)*
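For the training-time parallel form, the same masked computation can be vectorized with `torch.cumsum`. A hedged sketch (mine, not the paper's kernel): plain autograd through it already runs in linear time, but it stores the whole $(N, D, M)$ stack of intermediate $S_i$ for the backward pass – exactly the memory the paper's derived gradients save.

```python
import torch
import torch.nn.functional as F

def feature_map(x):
    return F.elu(x) + 1

def causal_linear_attention_parallel(Q, K, V):
    """Vectorized eq (8) via cumulative sums. Q, K: (N, D); V: (N, M)."""
    phi_Q, phi_K = feature_map(Q), feature_map(K)
    # all S_i at once: (N, D, M) intermediate -- the memory autograd must keep
    S = torch.cumsum(phi_K.unsqueeze(-1) * V.unsqueeze(1), dim=0)
    Z = torch.cumsum(phi_K, dim=0)  # all Z_i: (N, D)
    num = torch.einsum('nd,ndm->nm', phi_Q, S)
    den = torch.einsum('nd,nd->n', phi_Q, Z).unsqueeze(-1)
    return num / den

# gradients flow through cumsum, so training works out of the box,
# just with more memory than the paper's custom CUDA kernel
Q = torch.randn(128, 32, requires_grad=True)
K, V = torch.randn(128, 32), torch.randn(128, 16)
causal_linear_attention_parallel(Q, K, V).sum().backward()
print(Q.grad.shape)  # torch.Size([128, 32])
```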

3.3.2 Training and Inference

Transformer ๋Œ€๋น„ ์ข‹์€ ์ ์€ Inference ์‹œ์— QK๋ฅผ ์•ˆ๊ฐ€์ง€๊ณ  ์žˆ์–ด๋„ ๋˜์–ด์„œ ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ seq len์— ๋น„๋ก€ํ•˜์—ฌ ๋Š˜์–ด๋‚˜์ง€ ์•Š์Œ. ์ฆ‰ train, inference์˜ ์ข‹์€ ์ ์„ ๋‹ค ๊ฐ€์ ธ์˜ด Image

3.4. Transformers are RNNs

$$s_0 = 0, \qquad z_0 = 0$$

$$s_i = s_{i-1} + \phi(x_i W_K)(x_i W_V)^T$$

$$z_i = z_{i-1} + \phi(x_i W_K)$$

$$y_i = f_l\!\left(\frac{\phi(x_i W_Q)^T s_i}{\phi(x_i W_Q)^T z_i} + x_i\right)$$

$s_i$ and $z_i$ act as a hidden state updated one token at a time, so a transformer layer with causal linear attention is an RNN.
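What that recurrence looks like as autoregressive decoding (my own illustrative code: `W_Q`, `W_K`, `W_V`, the shapes, and random inputs are all assumptions, and the residual/FFN part $f_l$ is omitted for brevity):

```python
import torch
import torch.nn.functional as F

def feature_map(x):
    return F.elu(x) + 1  # eq (7)

D, M = 32, 32
W_Q, W_K, W_V = (torch.randn(D, D) / D ** 0.5 for _ in range(3))  # hypothetical weights

s = torch.zeros(D, M)  # s_0 = 0
z = torch.zeros(D)     # z_0 = 0

def step(x_i):
    """Consume one token embedding x_i of shape (D,), return its attention output."""
    global s, z
    q = feature_map(x_i @ W_Q)
    k = feature_map(x_i @ W_K)
    v = x_i @ W_V
    s = s + torch.outer(k, v)  # s_i = s_{i-1} + phi(x_i W_K)(x_i W_V)^T
    z = z + k                  # z_i = z_{i-1} + phi(x_i W_K)
    return (q @ s) / (q @ z)   # the attention part of y_i, before f_l

for t in range(1000):          # state stays (D, M) + (D,), never grows with t
    y = step(torch.randn(D))
print(s.shape, z.shape)        # torch.Size([32, 32]) torch.Size([32])
```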

Experiment

์Šคํ‚ต ใ…Ž