TL;DR
- I read this because.. : ์ธ๊ธ๋์ด. step์ ์ด๋ป๊ฒ ๋๋๋ค๋๊ฑด์ง ๊ถ๊ธํด์ ์ฝ์.
- task : LLM in reasoning
- problem : DPO๊ฐ ๋๋ฆฌ ์ฐ์ด๊ณ ์์ผ๋ long-context์์ ์ฑ๋ฅ ํฅ์์ ์ ํ๋จ
- idea : ๊ธด reasoning์ด ์๋ ๊ฒฝ์ฐ ํ๋ฆฐ step์ ๋ํด์ win / lose step์ ์ต๋ํํ๋ DPO loss
- architecture : Qwen2, Qwen1.5 Meta-Llama-3-70B, deepseek-math-7b-base
- objective : step-DPO (proposed)
- baseline : SFT, DPO
- data : ์ฒ์์ผ๋ก ํ๋ฆฐ step์ด ์ ์ฅ๋์ด ์๋ 374K pair ๋ฐ์ดํฐ(proposed), AQuA
- evaluation : MATH, GSM8K, AIME, Odyssey-MATH
- result : DPO๋ณด๋ค ๋์ ์ฑ๋ฅ. GPT-4-1106, Claude-3-Opus, Gemini-1.5-Pro๋ฅผ ์ด๊ฒผ๋ค๊ณ ํจ.
- contribution : data ๊ณต๊ฐ. ์ด๋ฐ ๋ฅ๊ฐ ๋ง์๊ฒ ๊ฐ์๋ฐ ์ด๊ฒ ์ฒ์์ธ์ง๋ ๋ชจ๋ฅด๊ฒ ์
- etc. :
Details
Performance
motivation
์ด ๋ ผ๋ฌธ์์ ๋งํ๋ SFT์ ๋จ์ ์ desirable output ๋ฟ ์๋๋ผ undesirable output์ ๋ํ likelihood๋ ๋์ธ๋ค๋ ์ ์ -> prone to hallucination ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด undesriable supervision์ ์ฃผ๋๊ฒ RLHF์ธ๋ฐ DPO์ ๊ฒฝ์ฐ long sequence output์ ๋ํด ํจ๊ณผ๊ฐ ์ข์ง ์๋ค๊ณ ํจ. (finegrained process supervision์ด ์์ด์๋ผ๊ณ ํํ)
Step-DPO
์ ์ฒด ์ํ์ค๊ฐ ์๋๋ผ ํ๋ฆฐ step์ ๋ํด์ win – lose margin์ ์ต๋ํํ๋๋ก
- $s_i$ : i๋ฒ์งธ reasoning step
- $x$ : prompt
- $k$ : ์ต์ด๋ก ํ๋ฆฐ step
In-distribtuion data construction
์๋์ ๊ฐ์ด ๋ง๋๋๊ฒ ๋ชฉํ
ํ์ดํ๋ผ์ธ
error collection problems x ์ gt answer $\hat{y}$๋ฅผ ๋ชจ์. reference model $\pi_{ref}$๋ฅผ ๊ฐ์ง๊ณ step-wise CoT preifx๋ก ์คํํด์ step์ผ๋ก ๋๋ final answer y๊ฐ gt answer๊ฐ ๋ค๋ฅธ ๊ฒ๋ค์ ๋ชจ์.
step localization reasoning step $y=s_1, s_2, … , s_n$์์ ์ฒ์์ผ๋ก ํ๋ฆฐ $k$๋ฅผ ์ฐพ์. (manually or gpt-4๋ฅผ ํตํด) ํ๋ฆฐ step k์ ์๋ฌ๋ฅผ $s_{lose}$๋ก ์ ์
rectification ๋ง๋ ressoning step $s_{1~{k-1}}$์ ์ฃผ์ด์ฃผ๊ณ ์ฌ๋ฌ๋ฒ reference model์ inferํด์ ์ฌ๋ฌ๊ฐ ๊ตฌํจ
์ด์ค์ final answer๊ฐ gt์ ๋ง๋ ๊ฑธ $s_{win}$์ผ๋ก ์ ์ ํจ. ์ด๋ ์ ๋ต์ด ๋ง๋๋ผ๋ ๊ณผ์ ์ด ํ๋ฆด ์ ์๋๋ฐ ์ด๋ manually or gpt-4๋ก ์ ์ ํจ (๊ทธ๋ฆผ์์๋ ์๋ต๋์ด ์์)
Result
- ์ ์ฒด 374K๋ฅผ ๋ชจ์๊ณ , ์ด์ค 299K๊ฐ SFT ๋ฐ์ดํฐ๋ก ์ฐ์๊ณ ๋๋จธ์ง 75K๋ Step-DPO๋ก ์ฐ์
- SFT๋ 3 or 2 ์ํญ
- Step-DPO๋ 8 or 4 ์ํญ ๋๋ฆผ
- SFT dataset์ ์ถ๊ฐ์ ์ผ๋ก AQuA ๋ฐ์ดํฐ ์ ์ฌ์ฉํจ
Ablation
DPO vs Step-DPO
in-distribution vs out-distribution
์ฌ์ฉํ๋ ๋ฐ์ดํฐ๊ฐ ์ฐ๋ฆฌ๊ฐ ํ์ตํ ๋ชจ๋ธ์ inference ๊ฒฐ๊ณผ์ธ๊ฒ ์ค์ํ๋ค๊ณ ํจ