image

paper, code

TL;DR

  • I read this because.. : I wondered whether it was related to dense RLHF
  • task : RLHF
  • problem : DPO imposes its loss on relative log probs. As a result, on pairs with a small edit distance, the log prob drops even for the parts of the answer that are not wrong.
  • idea : add a penalty so that the log prob of the preferred answer does not drop too low
  • input/output : query -> answer
  • architecture : Llama-2-7B-Chat, Bagel-34B-v0.2, MoMo-72b-lora-1.8.7-DPO
  • objective : proposed DPOP loss (the DPO loss with a penalty $\lambda \cdot \max\left(0, \log \frac{\pi_{\text{ref}}(y_w|x)}{\pi_{\theta}(y_w|x)}\right)$ subtracted from the reward margin inside the sigmoid)
  • baseline : DPO, IPO, SLiC
  • data : GSM8K, MetaMath, ARC, and Hellaswag, rebuilt into preference pairs by deliberately constructing incorrect answers.
  • evaluation : GSM8K / ARC / Hellaswag test split
  • result : outperforms the baselines on both low- and high-edit-distance datasets.
  • contribution : it is intuitively easy to see in which situation the problem arises, and the fix is just as intuitive.
  • etc. : it turned out to be unrelated to dense RLHF, but I'll count it as related..?! lol

Details

motivation

image

The DPO loss is shown above. The problem the authors emphasize is that the loss depends only on relative log probs (written as $\pi_{ratio}$ in the paper). All the loss requires is that this ratio be higher for the preferred answer than for the dispreferred one. Consequently $\pi_{ratio}(y_w)$ can keep decreasing even for the preferred answer $y_w$, as long as $\pi_{ratio}(y_l)$ decreases faster. This becomes a problem specifically when running DPO on pairs with a small edit distance.
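A minimal Python sketch of the per-pair DPO loss, computed from summed sequence log-probs, makes this concrete. The log-prob values are made up for illustration and `beta=0.1` is only a common choice. The loss falls from $\log 2$ even though $\log \pi_{ratio}(y_w)$ dropped by 5 nats, because $\log \pi_{ratio}(y_l)$ dropped by 15.

```python
import math


def log_sigmoid(z: float) -> float:
    # Numerically stable log(sigmoid(z)).
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """DPO loss for one pair, from summed sequence log-probs."""
    ratio_w = logp_w - ref_logp_w  # log pi_theta(y_w|x) / pi_ref(y_w|x)
    ratio_l = logp_l - ref_logp_l  # log pi_theta(y_l|x) / pi_ref(y_l|x)
    return -log_sigmoid(beta * (ratio_w - ratio_l))


ref_w, ref_l = -20.0, -22.0
# At initialization the policy equals the reference, so the loss is log 2.
init = dpo_loss(ref_w, ref_l, ref_w, ref_l)
# Later: y_w's log-prob fell by 5 nats, y_l's by 15 nats.
later = dpo_loss(ref_w - 5.0, ref_l - 15.0, ref_w, ref_l)
print(f"{init:.4f} -> {later:.4f}")  # prints: 0.6931 -> 0.3133
```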

image

Taking the gradient of the DPO loss gives:

image
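For reference, here is the standard DPO gradient as given in the original DPO paper. It is written with $\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}$, and the notation may differ slightly from this paper's figure:

$$
\nabla_\theta \mathcal{L}_{\text{DPO}} = -\beta\, \mathbb{E}_{(x, y_w, y_l)}\left[\sigma\big(\hat{r}_\theta(x, y_l) - \hat{r}_\theta(x, y_w)\big)\big(\nabla_\theta \log \pi_\theta(y_w|x) - \nabla_\theta \log \pi_\theta(y_l|x)\big)\right]
$$

The same weight $\sigma(\cdot)$ pushes $\log \pi_\theta(y_w|x)$ up and $\log \pi_\theta(y_l|x)$ down. Nothing in the expression anchors $\log \pi_\theta(y_w|x)$ to its reference value.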

For ease of discussion, suppose $y_w$ and $y_l$ differ only at the first token. Then the gradient for a later token $k$ is:

image

  • $s_j^{x}$ is the probability of predicting the $j$-th token given $x$
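The derivation relies on one step: the sequence log-prob is a sum of per-token log-softmax terms, so its gradient with respect to the logits $z_k$ at position $k$ has a simple form. In my notation (not necessarily the paper's), $s_k = \mathrm{softmax}(z_k)$ is the predicted distribution at position $k$, playing the role of $s_j^{x}$, and $e_{y_k}$ is the one-hot vector of the observed token:

$$
\log \pi_\theta(y|x) = \sum_{k=1}^{|y|} \log s_{k, y_k}, \qquad \frac{\partial \log \pi_\theta(y|x)}{\partial z_k} = e_{y_k} - s_k
$$

For $k > 1$, $y_w$ and $y_l$ share the same token $y_k$ but condition on different prefixes, so $s_k$ differs between the two sequences. With parameters shared across contexts, the two terms in the DPO gradient therefore do not cancel.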

We usually start DPO from weights that have already been through SFT, so the tokens that come after the wrong token inevitably have low log probs. Even though those later tokens are actually correct, the log probs of the two answers differ on them, and so a loss is imposed on them. In other words, the distribution at the wrong token gets corrected, but the log prob of the correct tokens that follow it is pushed down unnecessarily. That is the problem.
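Below is a toy simulation of this effect. Its simplifying assumptions are mine, not the paper's:

- The vocabulary has 3 tokens.
- The first token has its own trainable logits.
- Every later token uses one shared trainable logit vector plus a fixed, hypothetical context offset (`CONTEXT`) that depends only on the first token. The offset stands in for an SFT model that is confident about the correct continuation after a correct first token and unsure after a wrong one.
- Gradients are central finite differences, so only the standard library is needed.

After DPO training on one pair that differs only at the first token, the log-prob of the shared, correct later token in $y_w$ goes down. With enough later tokens, $\log \pi_\theta(y_w|x)$ as a whole goes down too.

```python
import math

V = 3                  # toy vocabulary size
T = 4                  # number of tokens after the first one
Y_W = [0] + [2] * T    # preferred: correct first token, then token 2 repeated
Y_L = [1] + [2] * T    # dispreferred: differs from Y_W only at the first token
# Hypothetical fixed context offsets for later-token logits, keyed by the first
# token: after token 0 the model favors token 2, after token 1 it is unsure.
CONTEXT = {0: [0.0, 0.0, 2.0], 1: [0.0, 0.0, 0.0]}


def log_softmax(z):
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]


def later_logprobs(params, first):
    # Shared trainable logits u = params[V:] plus the fixed context offset.
    u = params[V:]
    return log_softmax([ui + ci for ui, ci in zip(u, CONTEXT[first])])


def seq_logp(params, y):
    """log pi(y|x): first-token log-prob plus all later-token log-probs."""
    total = log_softmax(params[:V])[y[0]]
    later = later_logprobs(params, y[0])
    return total + sum(later[tok] for tok in y[1:])


def log_sigmoid(z):
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def dpo_loss(params, ref, beta=0.5):
    r_w = seq_logp(params, Y_W) - seq_logp(ref, Y_W)
    r_l = seq_logp(params, Y_L) - seq_logp(ref, Y_L)
    return -log_sigmoid(beta * (r_w - r_l))


def num_grad(f, p, eps=1e-5):
    # Central finite differences, one parameter at a time.
    grad = []
    for i in range(len(p)):
        hi, lo = p[:], p[:]
        hi[i] += eps
        lo[i] -= eps
        grad.append((f(hi) - f(lo)) / (2 * eps))
    return grad


ref = [0.0] * (2 * V)  # reference policy = initial (SFT) parameters
params = ref[:]
for _ in range(500):  # plain gradient descent, learning rate 0.5
    g = num_grad(lambda p: dpo_loss(p, ref), params)
    params = [p - 0.5 * gi for p, gi in zip(params, g)]

print("log p(token 2 | y_w prefix):",
      round(later_logprobs(ref, 0)[2], 3), "->", round(later_logprobs(params, 0)[2], 3))
print("log pi(y_w|x):",
      round(seq_logp(ref, Y_W), 3), "->", round(seq_logp(params, Y_W), 3))
```

In this toy, the later-token part of the DPO margin keeps growing as the shared logit for token 2 decreases. Token 2 becomes unlikely after both prefixes, and this hurts $y_l$ (whose prefix already made token 2 unlikely) more than $y_w$.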

Proposed method: DPOP

image

A penalty term is added so that the probability of the preferred answer, $\pi_\theta(y_w|x)$, does not fall below $\pi_{ref}(y_w|x)$.
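Here is a minimal sketch of the DPOP loss for a single pair, based on my reading of the paper's objective. The penalty $\lambda \cdot \max\left(0, \log \frac{\pi_{ref}(y_w|x)}{\pi_\theta(y_w|x)}\right)$ is subtracted from the DPO reward margin inside the sigmoid, and `lam=0` recovers plain DPO. The defaults `beta=0.1` and `lam=50.0` are illustrative assumptions, and the log-prob values are made up. Two updates that reach the same 10-nat margin get the same DPO loss, but DPOP heavily penalizes the one where $y_w$ fell below the reference.

```python
import math


def log_sigmoid(z: float) -> float:
    # Numerically stable log(sigmoid(z)).
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpop_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1, lam=50.0):
    """DPOP loss for one pair from summed sequence log-probs (sketch).

    lam=0 recovers plain DPO; beta and lam defaults are illustrative.
    """
    ratio_w = logp_w - ref_logp_w   # log pi_theta(y_w|x) / pi_ref(y_w|x)
    ratio_l = logp_l - ref_logp_l   # log pi_theta(y_l|x) / pi_ref(y_l|x)
    penalty = max(0.0, -ratio_w)    # = max(0, log pi_ref(y_w|x) / pi_theta(y_w|x))
    return -log_sigmoid(beta * (ratio_w - ratio_l - lam * penalty))


ref_w, ref_l = -20.0, -22.0
# Two updates that reach the same 10-nat margin:
drop_both = (ref_w - 5.0, ref_l - 15.0)  # y_w fell below the reference
raise_w = (ref_w + 2.0, ref_l - 8.0)     # y_w rose above the reference
for name, (lw, ll) in [("drop_both", drop_both), ("raise_w", raise_w)]:
    print(name,
          "DPO:", round(dpop_loss(lw, ll, ref_w, ref_l, lam=0.0), 4),
          "DPOP:", round(dpop_loss(lw, ll, ref_w, ref_l), 4))
```

The penalty is zero whenever $\pi_\theta(y_w|x) \ge \pi_{ref}(y_w|x)$. DPOP therefore behaves exactly like DPO until the preferred answer starts losing probability relative to the reference.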

Result

image image