
paper

TL;DR

  • I read this because.. : continuing the efficient finetuning series after #113
  • task : LLM finetuning
  • problem : full finetuning is inefficient. Adapters insert extra layers in the middle of the network, so they add inference latency.
  • idea : approximate the weight update with a low-rank decomposition and add it to the original parameters!
  • architecture : RoBERTa, DeBERTa, GPT-2, GPT-3
  • objective : ce loss
  • baseline : finetuning / adapters / prefix-layer tuning
  • data : GLUE, WikiSQL, MultiNLI
  • result : better performance with far fewer trainable parameters
  • contribution : efficient finetuning with no added inference latency
  • limitation / things I cannot understand :

Details

  • preliminaries : Parameter-Efficient Transfer Learning for NLP proposed adapters. Full finetuning is inefficient because every parameter must be trained and stored for each task, while feature extraction is limited in performance. That paper proposes adapters that learn each downstream task with far fewer parameters, inserting two adapter layers into every Transformer layer. (figure)
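To make the latency point concrete, here is a minimal sketch of a bottleneck adapter in the spirit of the paper above: down-project, apply a nonlinearity, up-project, then add a residual. The function name `adapter`, the ReLU, the sizes, and the zero-initialized up-projection are illustrative assumptions, not the original implementation.

```python
import numpy as np

def adapter(h: np.ndarray, W_down: np.ndarray, W_up: np.ndarray) -> np.ndarray:
    """Bottleneck adapter: down-project, nonlinearity, up-project, then a residual add."""
    z = np.maximum(h @ W_down, 0.0)  # ReLU chosen here for simplicity (assumption)
    return h + z @ W_up              # residual connection keeps the original signal

rng = np.random.default_rng(0)
d, m = 16, 2                         # hidden size d, bottleneck size m << d (illustrative)
h = rng.normal(size=(3, d))          # 3 token representations
W_down = rng.normal(scale=0.02, size=(d, m))
W_up = np.zeros((m, d))              # zero up-projection: the adapter starts as an identity map

out = adapter(h, W_down, W_up)
print(np.allclose(out, h))           # True
print(W_down.size + W_up.size)       # 64 trainable weights per adapter (biases omitted)
```

The nonlinearity between the two projections means the adapter cannot be folded into an existing weight matrix. It therefore stays on the forward path as extra sequential computation at inference time, which is the latency cost LoRA avoids.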
  • architecture (figure)

๊ธฐ๋ณธ์ ์ธ ์•„์ด๋””์–ด๋Š” denseํ•œ layer๊ฐ€ ๋” ๋‚ฎ์€ rank๋กœ decompose๋  ์ˆ˜ ์žˆ๋‹ค๋Š” ์•„์ด๋””์–ด. ์–ด๋–ค weight W์˜ update ๋ถ„์ธ $\Delta W$๋ฅผ $BA$ $B\in\mathbb{R}^{d \times r}$, $A\in\mathbb{R}^{r \times k}$๋กœ ๊ทผ์‚ฌํ•ด์„œ forward๋ฅผ ์•„๋ž˜์™€ ๊ฐ™์ด ๋งŒ๋“ฆ

$$h = W_0 x + \Delta W x = W_0 x + BAx$$
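The saving comes from storing $B$ and $A$ instead of a full $d \times k$ update: $r(d + k)$ trainable entries instead of $dk$. A quick count, using a square $12288 \times 12288$ matrix (GPT-3 175B's hidden size) purely as an example:

```python
def full_update_params(d: int, k: int) -> int:
    """Trainable entries if Delta W (d x k) is learned directly."""
    return d * k

def lora_params(d: int, k: int, r: int) -> int:
    """Trainable entries of B (d x r) plus A (r x k)."""
    return d * r + r * k

d = k = 12288  # GPT-3 175B hidden size, used only as an example
for r in (1, 4, 8):
    full = full_update_params(d, k)
    lora = lora_params(d, k, r)
    print(f"r={r}: {lora:,} vs {full:,} ({full // lora}x fewer)")
# r=1: 24,576 vs 150,994,944 (6144x fewer)
# r=4: 98,304 vs 150,994,944 (1536x fewer)
# r=8: 196,608 vs 150,994,944 (768x fewer)
```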

Here $A$ is initialized from a random Gaussian and $B$ with zeros, so $BA = 0$ at the start of training. $\Delta W x$ is scaled by $\alpha / r$, where $\alpha$ is a constant hyperparameter that acts somewhat like a learning rate. LoRA is applied only to the attention weights $W_q$, $W_k$, $W_v$, $W_o$, not to the MLP.
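A minimal NumPy sketch of a LoRA-adapted linear map with this initialization and $\alpha / r$ scaling. The class name `LoRALinear`, the Gaussian standard deviation of 0.01, and the toy shapes are assumptions for illustration; this is not the authors' released code.

```python
import numpy as np

class LoRALinear:
    """Sketch of h = W0 x + (alpha / r) * B A x with W0 frozen and only A, B trainable."""

    def __init__(self, W0: np.ndarray, r: int, alpha: float, rng: np.random.Generator):
        d, k = W0.shape
        self.W0 = W0                                   # frozen pretrained weight (d x k)
        self.A = rng.normal(0.0, 0.01, size=(r, k))    # random Gaussian init (std is an assumption)
        self.B = np.zeros((d, r))                      # zero init, so B A = 0 at the start
        self.scale = alpha / r                         # scaling factor alpha / r

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x holds row vectors (batch, k), so W0 x is written as x @ W0.T
        return x @ self.W0.T + self.scale * (x @ self.A.T) @ self.B.T

rng = np.random.default_rng(0)
W0 = rng.normal(size=(6, 4))                  # d=6, k=4
layer = LoRALinear(W0, r=2, alpha=4.0, rng=rng)
x = rng.normal(size=(3, 4))
print(np.allclose(layer(x), x @ W0.T))        # True: the model starts exactly at W0
print(layer.A.size + layer.B.size)            # 20 = r * (d + k) trainable entries
```

Because $B = 0$, training starts from the pretrained model's behavior. The paper notes that with Adam, tuning $\alpha$ is roughly like tuning the learning rate, so dividing by $r$ reduces the need to retune when $r$ changes.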


์ œํ•œ๋œ ํŒŒ๋ผ๋ฏธํ„ฐ ์ œ์•ฝ ์•ˆ์—์„œ $W_q$๋งŒ ์ ์šฉํ•˜๋Š” ๊ฒƒ๋ณด๋‹ค rank 4๋”๋ผ๋„ ๋‘˜๋‹ค ์ ์šฉํ•˜๋Š”๊ฒŒ ์ข‹์•˜๊ณ  ์…‹๋‹ค ์ ์šฉํ•˜๋Š”๊ฒŒ ๊ฐ€์žฅ ์ข‹์•˜์Œ. image

๋งค์šฐ ๋‚ฎ์€ rank์—์„œ๋„ ์ž˜ ์ž‘๋™ํ–ˆ๊ณ  ์ด๋Š” update matrix $\Delta W$ ๊ฐ€ ๋งค์šฐ ๋‚ฎ์€ intrinsic matrix๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ๋‹ค๋Š” ๋œป์ž„.

  • inference latency (figure)
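The reason there is no extra latency: after training, $\frac{\alpha}{r} BA$ can be added into $W_0$ once, so inference uses a single ordinary matrix, the same as the original model. A minimal sketch with random stand-ins for trained $B$ and $A$ (purely illustrative values):

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r, alpha = 8, 5, 2, 4.0
W0 = rng.normal(size=(d, k))     # frozen pretrained weight
B = rng.normal(size=(d, r))      # stand-ins for trained LoRA factors (random here)
A = rng.normal(size=(r, k))
scale = alpha / r

def lora_forward(x: np.ndarray) -> np.ndarray:
    # training-time form: frozen path plus the low-rank branch
    return x @ W0.T + scale * (x @ A.T) @ B.T

W_merged = W0 + scale * (B @ A)  # fold the update into the weight once, after training

x = rng.normal(size=(3, k))
print(np.allclose(lora_forward(x), x @ W_merged.T))   # True: same output, one matmul
W_restored = W_merged - scale * (B @ A)               # subtract to switch to another task
print(np.allclose(W_restored, W0))                    # True
```

Switching tasks only requires subtracting one $BA$ and adding another, so a single copy of $W_0$ can serve many tasks.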

  • results (figure)
