TL;DR
- I read this because.. : dense RLHF ๊ด๋ จ์ธ๊ฐ ์ถ์ด์
- task : RLHF
- problem : DPO๋ ์๋์ ์ธ log prob์ ๋ํ loss๋ฅผ ๋ถ๊ฐํ๊ธฐ ๋๋ฌธ์ edit distance๊ฐ ์ ์ pair์ ๊ฒฝ์ฐ ํ๋ฆฌ์ง ์์ ๋ถ๋ถ์ ๋ํด์๋ log prob์ด ๋ฎ์์ง๋ ๊ฒ ๊ด์ฐฐ
- idea : preferred answer์ ๋ํ log prob์ด ๋๋ฌด ๋ฎ์์ง์ง ์๋๋ก penalty ๋ถ๊ฐ
- input/output : query -> answer
- architecture : Llama-2-7B-Chat, Bagel-34B-v0.2, MoMo-72b-lora-1.8.7-DPO
- objective : proposed DPOP loss(DPO loss + $\max\left(0, \log \frac{\pi_{\text{ref}}(y_w|x)}{\pi_{\theta}(y_w|x)}\right)$ )
- baseline : DPO, IPO, SLiC
- data : GSM8K, MetaMath, ARC, Hellaswag๋ฅผ ์ผ๋ถ๋ฌ ํ๋ฆฐ pair๋ฅผ ๋ง๋๋ ์์ผ๋ก ํด์ ๋ค์ ๋ง๋ฆ.
- evaluation : GSM8K / ARC / Hellaswag test split
- result : edit distance๊ฐ ๋ฎ์/๋์ ๋ฐ์ดํฐ์ ๋ชจ๋์์ ๋ฒ ์ด์ค๋ผ์ธ๋ณด๋ค ๋์ ์ฑ๋ฅ.
- contribution : ์ด๋ค ์ํฉ์์ ๋ฌธ์ ๊ฐ ์๊ธด๊ฑด์ง ์ง๊ด์ ์ผ๋ก ์ดํดํ๊ธฐ ์ฝ๊ณ ํด๊ฒฐ ๋ฐฉ๋ฒ๋ ์ง๊ด์ ์
- etc. : dense RLHF๋์ ์๊ด ์์์ง๋ง ์๊ด ์๋๊ฑธ๋ก..?! ใ ใ
Details
motivation
DPO์ loss๋ ์์ ๊ฐ์ ์ด๋ ์ ์๋ค์ด ๊ฐ์กฐํ๋ ๋ฌธ์ ๋ loss๊ฐ ์๋์ ์ธ log prob์๋ง ์์กดํ๋ค๋ ๊ฒ์. (๋ ผ๋ฌธ์์ $\pi_{ratio}$๋ก ํํ) ์ด ์๋์ ์ธ ๋น์จ์ด preferred ๊ฐ disprefered๋ณด๋ค ๋๊ธฐ๋ง ํ๋ฉด ๋๋๊น $y_w$์์๋ $\pi_{ratio}(y_w)$๋ ๊ณ์ ๋ฎ์์ง ์ ์์. ์ด๊ฒ์ด ์ด๋ ์ํฉ์ ๋ํด์ ๋ฌธ์ ๊ฐ ๋๋๋ฉด edit distance๊ฐ ์ ์ pair์ ๋ํด์ DPO๋ฅผ ํ ๋์.
DPO loss์ ๋ํด Gradient๋ฅผ ๊ตฌํ๋ฉด ์๋์ ๊ฐ์
์ด ๋ ๋
ผ์์ ํธ์์ฑ์ ์ํด ์ฒซ๋ฒ์งธ ํ ํฐ์์๋ง $y_w$, $y_l$์ด ๋ค๋ฅด๋ค๊ณ ํ์. ๊ทธ๋ฌ๋ฉด ๊ทธ ์ดํ ํ ํฐ $k$์ ๋ํ gradient๋ ์๋์ ๊ฐ๋ค.
- $s_j^{x}$ ๋ x๊ฐ ์ฃผ์ด์ก์ ๋ j ๋ฒ์งธ ํ ํฐ์ ์์ธกํ๋ ํ๋ฅ
์ฐ๋ฆฌ๋ ๋ณดํต DPO๋ฅผ SFT๊ฐ ์๋ฃ๋ weight์์ ์์ํ๊ธฐ ๋๋ฌธ์ ํ๋ฆฐ ํ ํฐ ์ดํ์ ๋์ค๋ ํ ํฐ์ ๋ํด์๋ log prob์ด ๋ฎ์ ์ ๋ฐ์ ์์. ๊ทธ๋ฌ๋ฉด ๋ค์ ํ ํฐ๋ค์ ์ฌ์ค์ ๋ง๋ ํ ํฐ์์๋ ๋ถ๊ตฌํ๊ณ ๋์ log prob์ ์ฐจ์ด๊ฐ ์๊ธฐ๊ธฐ ๋๋ฌธ์ Loss๊ฐ ๋ถ๊ฐ๋จ. ์ฆ ํ๋ฆฐ ํ ํฐ์ ๋ํ ํ๋ฅ ๋ถํฌ๋ ๋ง๊ฒ ์์ ๋์ง๋ง ๊ทธ ์ดํ์ ๋ง๋ ํ ํฐ์ ๋ํด์๋ ๋ถํ์ํ๊ฒ log prob์ด ๋ฎ์์ง๊ฒ ๋๋ ๊ฒ์ด ๋ฌธ์
Propose DPOP
penalty term ์ถ๊ฐ. prefered answer์ ๋ํด $\pi_{ref}$๋ณด๋ค ๋ฎ์์ง์ง ์๋๋ก.