
paper, code/data

TL;DR

  • I read this because.. : It was mentioned somewhere, and I was curious how the reasoning is split into steps.
  • task : reasoning with LLMs
  • problem : DPO is widely used, but its gains are limited on long reasoning chains (long sequence outputs)
  • idea : For long reasoning chains, apply a DPO loss at the first erroneous step, maximizing the margin between the winning and the losing step
  • architecture : Qwen2, Qwen1.5, Meta-Llama-3-70B, deepseek-math-7b-base
  • objective : step-DPO (proposed)
  • baseline : SFT, DPO
  • data : 374K pairs that record the first erroneous step (proposed), AQuA
  • evaluation : MATH, GSM8K, AIME, Odyssey-MATH
  • result : Outperforms DPO. The authors report that it beats GPT-4-1106, Claude-3-Opus, and Gemini-1.5-Pro.
  • contribution : Releases the data. There seem to be many works of this kind, so I'm not sure whether this is the first.
  • etc. :

Details

Performance


Motivation

์ด ๋…ผ๋ฌธ์—์„œ ๋งํ•˜๋Š” SFT์˜ ๋‹จ์ ์€ desirable output ๋ฟ ์•„๋‹ˆ๋ผ undesirable output์— ๋Œ€ํ•œ likelihood๋„ ๋†’์ธ๋‹ค๋Š” ์ ์ž„ -> prone to hallucination ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด undesriable supervision์„ ์ฃผ๋Š”๊ฒŒ RLHF์ธ๋ฐ DPO์˜ ๊ฒฝ์šฐ long sequence output์— ๋Œ€ํ•ด ํšจ๊ณผ๊ฐ€ ์ข‹์ง€ ์•Š๋‹ค๊ณ  ํ•จ. (finegrained process supervision์ด ์—†์–ด์„œ๋ผ๊ณ  ํ‘œํ˜„)


Step-DPO


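Instead of the whole sequence, the loss maximizes the win–lose margin only at the erroneous step:

$$\mathcal{L}(\theta) = -\mathbb{E}_{(x, s_{1\sim k-1}, s_{win}, s_{lose}) \sim D}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(s_{win}\mid x; s_{1\sim k-1})}{\pi_{ref}(s_{win}\mid x; s_{1\sim k-1})} - \beta\log\frac{\pi_\theta(s_{lose}\mid x; s_{1\sim k-1})}{\pi_{ref}(s_{lose}\mid x; s_{1\sim k-1})}\right)\right]$$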

  • $s_i$ : the i-th reasoning step
  • $x$ : prompt
  • $k$ : index of the first erroneous step
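The Step-DPO objective can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation. It assumes you already have per-token log-probabilities from $\pi_\theta$ and $\pi_{ref}$. The helper names and the default β = 0.1 are hypothetical choices made for illustration. The key difference from plain DPO is that only the tokens of the single step ($s_{win}$ or $s_{lose}$) are summed. The prompt $x$ and the correct prefix $s_{1\sim k-1}$ are only conditioned on.

```python
import math
from typing import Sequence

def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))

def step_logprob(token_logprobs: Sequence[float], step_mask: Sequence[bool]) -> float:
    """Sum log-probs over the tokens of one step only (hypothetical helper).

    Tokens of the prompt x and of the correct prefix s_1..s_{k-1} are masked out
    (False); plain DPO would instead sum over the whole response.
    """
    return sum(lp for lp, m in zip(token_logprobs, step_mask) if m)

def step_dpo_loss(policy_win: float, ref_win: float,
                  policy_lose: float, ref_lose: float,
                  beta: float = 0.1) -> float:
    """Per-example loss: -log sigmoid(beta * (win log-ratio - lose log-ratio)).

    beta = 0.1 is an illustrative default, not a value taken from the paper.
    """
    margin = beta * ((policy_win - ref_win) - (policy_lose - ref_lose))
    return -log_sigmoid(margin)

# Toy usage with made-up numbers: the policy raised s_win and lowered s_lose
# relative to the reference, so the loss drops below log(2).
loss = step_dpo_loss(policy_win=-4.0, ref_win=-5.0, policy_lose=-8.0, ref_lose=-7.0)
print(round(loss, 4))
```

When $\pi_\theta = \pi_{ref}$ the margin is zero and the loss equals $\log 2$. Training pushes the margin up by making $s_{win}$ relatively more likely than $s_{lose}$ given the same prefix.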

In-distribution data construction

์•„๋ž˜์™€ ๊ฐ™์ด ๋งŒ๋“œ๋Š”๊ฒŒ ๋ชฉํ‘œ image

ํŒŒ์ดํ”„๋ผ์ธ image

  • Error collection: collect problems $x$ with ground-truth answers $\hat{y}$. Run the reference model $\pi_{ref}$ with a step-wise CoT prefix so that its output is split into steps, and keep the samples whose final answer $y$ differs from the ground-truth answer.

  • Step localization: in the reasoning steps $y = s_1, s_2, \dots, s_n$, find the first erroneous step $k$ (manually or with GPT-4), and take the erroneous step $s_k$ as $s_{lose}$.

  • Rectification: given the correct reasoning steps $s_{1\sim k-1}$ as a prefix, sample from the reference model several times, i.e. draw multiple continuations from $\pi_{ref}(\cdot \mid x; s_{1\sim k-1})$.

Among these, the first step of a continuation whose final answer matches the ground truth is selected as $s_{win}$. A correct final answer can still come from flawed reasoning, so these candidates are additionally filtered manually or with GPT-4 (this filtering is not shown in the pipeline figure).
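The three stages can be sketched as a toy Python pipeline. This is an illustration under stated assumptions, not the authors' code. The `Step N:` output format is assumed. Every callable is a hypothetical stand-in: `generate` and `sample_continuation` play the reference model, `locate_error` plays the manual/GPT-4 localization, and `final_answer` is a toy answer extractor. The extra manual/GPT-4 check of $s_{win}$'s reasoning is left out. In the toy example, the model makes an order-of-operations mistake.

```python
import re
from typing import Callable, Dict, List, Optional, Tuple

Steps = List[str]

def split_steps(cot: str) -> Steps:
    # Assumes the model was prompted to write "Step 1: ... Step 2: ..." (hypothetical format)
    return [p.strip() for p in re.split(r"(?=Step \d+:)", cot) if p.strip()]

def final_answer(steps: Steps) -> Optional[str]:
    # Toy answer extraction: reads "answer is N" from the last step
    m = re.search(r"answer is (\S+)", steps[-1]) if steps else None
    return m.group(1).rstrip(".") if m else None

def collect_errors(problems: List[Tuple[str, str]],
                   generate: Callable[[str], str]) -> List[Tuple[str, str, Steps]]:
    """Error collection: keep samples whose final answer differs from the ground truth."""
    errors = []
    for x, gt in problems:
        steps = split_steps(generate(x))
        if final_answer(steps) != gt:
            errors.append((x, gt, steps))
    return errors

def build_pair(x: str, gt: str, steps: Steps,
               locate_error: Callable[[str, Steps], Optional[int]],
               sample_continuation: Callable[[str, Steps], str],
               n_samples: int = 8) -> Optional[Dict]:
    # Step localization: 0-based index, i.e. k - 1 for the paper's 1-based step k
    k = locate_error(x, steps)
    if k is None:
        return None
    prefix, s_lose = steps[:k], steps[k]
    # Rectification: resample from the correct prefix and keep a continuation that reaches gt
    for _ in range(n_samples):
        continuation = split_steps(sample_continuation(x, prefix))
        if continuation and final_answer(prefix + continuation) == gt:
            # s_win is the first step of the correct continuation
            return {"x": x, "prefix": prefix, "s_win": continuation[0], "s_lose": s_lose}
    return None

# Toy stubs standing in for the reference model and the manual / GPT-4 localization
def generate(x: str) -> str:
    return "Step 1: 2 + 3 = 5. Step 2: 5 * 4 = 20, so the answer is 20."

def locate_error(x: str, steps: Steps) -> Optional[int]:
    return 0  # pretend the first step is the first error

def sample_continuation(x: str, prefix: Steps) -> str:
    return "Step 1: 3 * 4 = 12. Step 2: 2 + 12 = 14, so the answer is 14."

for x, gt, steps in collect_errors([("What is 2 + 3 * 4?", "14")], generate):
    print(build_pair(x, gt, steps, locate_error, sample_continuation))
```

With these stubs, the erroneous `Step 1: 2 + 3 = 5.` becomes `s_lose`, and `Step 1: 3 * 4 = 12.` becomes `s_win`, both with an empty prefix. Because the continuations are sampled from the same reference model, the pairs stay in-distribution.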

Result

  • 374K examples were collected in total. Of these, 299K were used as SFT data and the remaining 75K for Step-DPO
    • SFT was trained for 3 or 2 epochs
    • Step-DPO was trained for 8 or 4 epochs
  • The AQuA dataset was additionally used for SFT


Ablation

  • DPO vs. Step-DPO

  • In-distribution vs. out-of-distribution data

The authors report that it matters for the training data to be inference outputs of the model being trained (in-distribution).