
paper

TL;DR

  • I read this because.. : it was mentioned in a video about o1
  • task : math
  • problem : outcome-based vs process-based
  • idea : using GSM8K, train either 1) on the final answer only, or 2) on human- or model-generated reasoning traces with human correctness annotations added for each reasoning step.
  • architecture : their own base-70B (apparently not disclosed)
  • objective : CE (SFT) / BCE (ORM / PRM)
  • baseline : PaLM-540B, Minerva-540B, GPT-J-6B, Codex-175B, InstructGPT-175B, GPT-175B
  • data : GSM8K -> (eval) GSM8K-test, MATH
  • evaluation : final-answer error rate / trace error rate (human-annotated); MATH dataset for OOD error rate
  • result : 1) outcome- and process-based training reach similar final-answer error rates; 2) reward models trained with either process- or outcome-based supervision both end up approximating process-based feedback; 3) reducing trace error requires process-based feedback or a reward model
  • contribution : a variety of analyses. Honestly I'm not sure how important this paper is.

Details

(figure)

training: overview

  • step: newline-separated (one line is one step)
  • answer: last line
  • policy network: “each step” as an action, “all the tokens so far” as an observation (rollout sketched below)
    • trained with few-shot prompting, SFT, or RL
    • reward model -> used for reranking
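
A minimal sketch of the step-level rollout this formulation implies. `lm_sample(prefix)` and the `Answer:` marker are my placeholders, not from the paper; the 15-step cap is borrowed from the PRM-RL termination rule below:

```python
def sample_trace(lm_sample, problem: str, max_steps: int = 15) -> list[str]:
    """Roll out a solution one action (= one line) at a time.

    lm_sample(prefix) is assumed to return the next line of text, i.e. an
    LM sampling call that stops at "\n"; the growing prefix is the observation.
    """
    steps: list[str] = []
    prefix = problem + "\n"
    for _ in range(max_steps):
        step = lm_sample(prefix)        # action: one reasoning step
        steps.append(step)
        prefix += step + "\n"           # observation: all tokens so far
        if step.startswith("Answer:"):  # hypothetical final-answer marker
            break
    return steps
```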

SFT

  • trained on the full reasoning trace, not just the final answer.
  • trained until validation loss starts to rise; roughly 2 epochs.

Reward model

  • ORM: as in #209, trained with a binary label for whether the final answer is correct
  • PRM: trained with a binary label for whether the steps so far are correct
    • these labels come from human annotation.
  • Both are trained on samples drawn from the current policy model (temperature 1.0, K=96); both objectives are sketched after this list.
    • The ORM starts from the SFT model (or from the pretrained LM in the few-shot setting).
    • For the PRM, 3 samples per problem from the SFT policy network were sent for annotation.
    • The annotated problems were drawn mainly from those the SFT model got wrong.
    • The PRM is initialized from the ORM; since validation loss fluctuated, the best-val-loss checkpoint before step 2000 was chosen.

Decoding

  • draw 96 samples and apply several decoding schemes (all three rules are sketched after this list)
  • self-consistency (majority voting over final answers)
  • RM-weighted decoding (== verifier voting): each trace votes for its answer with weight equal to its RM score
    • slightly better than simply picking the trace with the highest RM score
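
A sketch of the three decoding rules over the K sampled traces. `answer(t)` extracts the final-answer string and `rm(t)` scores a full trace; both are placeholders for whatever extraction and scoring machinery is used:

```python
from collections import defaultdict

def majority_vote(traces, answer):
    votes = defaultdict(int)
    for t in traces:
        votes[answer(t)] += 1          # self-consistency: one vote per trace
    return max(votes, key=votes.get)

def rm_weighted_vote(traces, answer, rm):
    votes = defaultdict(float)
    for t in traces:
        votes[answer(t)] += rm(t)      # weight each vote by the RM score
    return max(votes, key=votes.get)

def rm_best(traces, answer, rm):
    return answer(max(traces, key=rm)) # single highest-scoring trace
```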

RL via Expert Iteration

(figure)

์•ˆ ์ฝ์–ด์„œ ์ž˜ ๋ชจ๋ฅด๊ฒ ์œผ๋‚˜ RL๋กœ ํ•™์Šต๋œ ์• ๋ฅผ Policy๋กœ ์จ์„œ trace๋“ค ๋ฝ‘๊ณ  ์ด๋ฅผ ๋ฐ˜๋ณตํ•˜๋Š”๊ฑธ ๋งํ•˜๋Š”๋“ฏ

  • SFT vs few-shot based
    • the initial policy network can be either the SFT model or the base LM with a 5-shot prompt
(figure)
  • Policy Improvement
    • final-answer RL (a.k.a. self-taught reasoner, STaR)
      • sample K traces per problem and filter by final-answer correctness
      • in the SFT case, only one trace per problem is kept (no reason given)
    • ORM-RL
      • among the K traces, select the one the ORM scores highest
    • PRM-RL
      • sample K (=96) candidate steps and pick the one with the highest PRM score; stop on a final answer or after 15 steps
      • with a few-shot base the RM was retrained every iteration; with SFT the RM was kept fixed (a one-round sketch of the final-answer variant follows this list)
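
A hedged sketch of one round of the final-answer (STaR-style) variant. `sample`, `is_correct`, and `finetune` stand in for the sampling, answer-checking, and SFT machinery; only K=96 and the keep-one-per-problem rule come from the notes above:

```python
def expert_iteration_round(policy, problems, sample, is_correct, finetune, k=96):
    kept = []
    for problem in problems:
        traces = [sample(policy, problem) for _ in range(k)]
        good = [t for t in traces if is_correct(problem, t)]
        if good:
            kept.append((problem, good[0]))  # SFT variant: keep one per problem
    return finetune(policy, kept)            # distill filtered traces into the policy

# Expert iteration = repeating this: policy = expert_iteration_round(policy, ...)
```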

Data annotation

  • For stepwise labels, annotators were asked to find the first incorrect step in each model-generated trace, where a step counts as incorrect if 1) its stated content is inaccurate, or 2) there is no longer any way to reach a correct answer without undoing that step (the expansion into per-step labels is sketched below).
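
My reading of how one "first incorrect step" annotation expands into the per-step binary labels the PRM trains on (a label means "all steps so far are correct"; this is an illustration, not paper code):

```python
def step_labels(num_steps: int, first_error: int | None) -> list[int]:
    """first_error: 0-based index of the first wrong step, or None if clean."""
    if first_error is None:
        return [1] * num_steps
    # prefixes ending before the first error are correct; once the error
    # occurs, every longer prefix contains it
    return [1] * first_error + [0] * (num_steps - first_error)

assert step_labels(4, None) == [1, 1, 1, 1]
assert step_labels(4, 2) == [1, 1, 0, 0]
```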

Result

(figure)
  • Training on final answers alone already improves performance to near-SFT levels (3.1, last row: SFT + majority voting 22.3 vs Few-shot + Final-Answer RL 23.5).
    • The authors attribute the remaining gap to supervision density: few-shot + final-answer RL gets only 1-4 tokens of supervision per problem, while SFT gets hundreds.
(figure)
  • ORM-supervised reward models ~= PRM
    • The figure above shows that ORM-trained predictions have high agreement with the PRM labels.
    • Also, comparing trace error for SFT + majority voting (11.4) vs SFT + ORM reranking (4.4), the ORM alone already cuts trace error substantially.
    • Though this result may hold only in this domain.
(figure)
  • low trace error requires process-based feedback or a reward model
    • Few-shot + Final-Answer RL and SFT + Majority Voting reach nearly the same final-answer error but differ a lot in trace error (19.8 vs 11.4).
    • The same trend appears for Few-shot + Final-Answer RL + ORM reranking (12.4) vs SFT + ORM / PRM reranking (4.4 / 3.4).
    • But adding ORM-RL brings few-shot + ORM-RL + ORM reranking down to a trace error of 5.5.
    • In short, either process-based SFT or a reward model is needed.
(figure)
  • RL ์€ Few-shot ์…‹ํŒ…์—์„œ๋Š” ์„ฑ๋Šฅ์„ ๋งŽ์ด ๊ฐœ์„ ํ–ˆ๊ณ  SFT์—์„œ๋Š” ์ ๋‹นํžˆ ๊ฐœ์„ ํ–ˆ๋”ฐ.
    • ํŠนํžˆ RM decoding + final answer rl์˜ ๊ฒฝ์šฐ ๊ฑฐ์˜ ์„ฑ๋Šฅ ๊ฐœ์„ ์ด ์—†๊ธฐ๋„ ํ–ˆ๋‹ค.