TL;DR
- I read this because.. : o1 ๊ด๋ จ ์์์์ ์ธ๊ธ๋์ด
- task : math
- problem : outcome-based vs process-based
- idea : GSM8K๋ฅผ ์ฌ์ฉํ์ฌ 1) final answer๋ง ํ์ต 2) human generated reasoning trace or model generated์ ๊ฐ reasoning step์ ๋ํ human correctness annotation์ ์ถ๊ฐํ์ฌ ํ์ต.
- architecture : ours base-70B (๋น๋ฐ์ธ๋ฏ)
- objective : ce (SFT) / bce (ORM / PRM)
- baseline : PaLM-540B, Minerva-540B, GPT-J-6B, Codex-175B, InstructGPT-175B, GPT-175B
- data : GSM8K -> (eval) GSM8K-test, MATH
- evaluation : final-answer error rate / trace error rate(human annotated). MATH dataset (OOD error rate)
- result : 1) outcome-, process- ๋ชจ๋ final answer error rate๋ ๋น์ท. 2) process-, outcome- RM ์ ์ฌ์ฉํ ๊ฒฝ์ฐ ๋ ๋ค process-based feedback ์์ฑ ๊ฐ๋ฅ 3) trace error๋ฅผ ์ค์ด๊ธฐ ์ํด์๋ process-based feedback or reward model์ด ํ์
- contribution : ๋ค์ํ ๋ถ์. ์ฌ์ค ์ผ๋ง๋ ์ค์ํ ๋ ผ๋ฌธ์ธ์ง๋ ๋ชจ๋ฅด๊ฒ ์
Details
training: overview
- step: new-line seperated (ํ ์ค์ด ํ step)
- answer: last line
- policy network: “each step” as an action, “all the tokens so far” as an observation
- train with few-shot prompt, SFT, RL
- reward model -> rerankingํ๋๋ฐ ์ฌ์ฉ๋จ
SFT
- reasoning trace๊น์ง ํ์ต.
- val loss ๊ฐ ์์นํ ๋ ๊น์ง ํ์ต. ๋๋ต 2์ํญ
Reward model
- ORM: #209 ์ ๋น์ทํ๊ฒ final answer๊ฐ ๋ง๋์ง ํ๋ฆฐ์ง์ ๋ํด binary label๋ก ํ์ต
- PRM: ์ง๊ธ๊น์ง์ step์ด ๋ง๋์ง binary label๋ก ํ์ต
- ์ด์ ๋ํ label์ human annotated๋ก ๋ฐ์.
- ๋ ๊ฐ ๋ชจ๋ ํ์ฌ policy ๋ชจ๋ธ์์ ๋์จ sample์ ์ฌ์ฉํ์ฌ ํ์ต. (temperature 1.0. K=96).
- ORM์ ๊ฒฝ์ฐ SFT์์ ์์, few-shot์ ๊ฒฝ์ฐ pretrained lm์์ ์์.
- PRM์ ๊ฒฝ์ฐ SFT policy network์์ ๋ฌธ์ ๋น 3๊ฐ์ ์ํ๋ก ์ด๋ ธํ ์ด์ ๋ฐ์.
- ์ด๋ ๋ฌธ์ ๋ SFT ์์ธก์ด ํ๋ฆฐ ์ ๋ค์ ์์ฃผ๋กํจ.
- PRM์ ๊ฒฝ์ฐ ORM ๋ชจ๋ธ๋ก ์ด๊ธฐํํ๊ณ val loss๊ฐ ๋ณ๋์ด ์ข ์์ด์ 2000 step ์ด์ ์ต์ Val Loss๋ก ์ ์
Decoding
- 96 samples ๋ฝ๊ณ ๋ช๊ฐ์ง ๋์ฝ๋ฉ ๊ธฐ๋ฒ ์ ์ฉ
- self-consistency
- RM weighted decoding (==verifier voting) – RM score๋งํผ weighted ํด์ voting
- highest RM score๋ก ์ ์ ํ๋ ๊ฒ๋ณด๋ค ์ฝ๊ฐ ์ข์์
RL via Expert Iteration
์ ์ฝ์ด์ ์ ๋ชจ๋ฅด๊ฒ ์ผ๋ RL๋ก ํ์ต๋ ์ ๋ฅผ Policy๋ก ์จ์ trace๋ค ๋ฝ๊ณ ์ด๋ฅผ ๋ฐ๋ณตํ๋๊ฑธ ๋งํ๋๋ฏ
- SFT vs few-shot based
- initial policy network๋ SFT์ด๊ฑฐ๋ 5-shot prompt๋ฅผ ํ base LM ์ด๊ฑฐ๋ ์ ํ์ ์ฌ์ง๊ฐ ์์
- Policy Improvement
- final-answer RL(a.k.a. self-taught reasoner)
- ๋ฌธ์ ๋น K๊ฐ์ ์ํ์ ๋ฝ๊ณ final-answer์ ์ ํ๋๋ก ํํฐ๋ง
- SFT์ ๊ฒฝ์ฐ ๋ฌธ์ ๋น ํ๋๋ง ์ ํ (์ด์ ๋ ์์)
- ORM-RL
- K๊ฐ์ traces ์ค ORM์ด ๊ฐ์ฅ ๋๊ฒ ์ ์๋ฅผ ๋งค๊ธด์ ๋ฅผ ์ ํ
- PRM-RL
- K(=96)๊ฐ์ candidate step์ ๋ฝ๊ณ PRM์์ ๊ฐ์ฅ ๋์ ์ ์๋ฅผ ๊ฐ์ง ์ ๋ฅผ ์ ํ. final answer์ด๊ฑฐ๋ 15 ์คํ ์ด ๋์ด๊ฐ๋ฉด ์ข ๋ฃ
- few-shot base์ผ ๊ฒฝ์ฐ RM์ ๋งค๋ฒ ์๋ก ํ์ตํ๊ณ SFT์ ๊ฒฝ์ฐ RM์ ๊ณ ์ ํจ
- final-answer RL(a.k.a. self-taught reasoner)
Data annotation
- stepwise label์ ๊ฒฝ์ฐ ์์ฑ๋ ๋ชจ๋ธ์์ ์ฒซ๋ฒ์งธ๋ก ํ๋ฆฐ step์ ์ฐพ์ผ๋ผ๊ณ ํจ. ์ด ๊ธฐ์ค์ 1) ํํ๋ ๋ด์ฉ์ด ๋ถ์ ํํ๊ฑฐ๋ 2) ์ด step์ undoํ์ง ์๋ ์ด์ ๋ง๋ ๋ต๋ณ์ผ๋ก ๊ฐ ๊ฐ๋ฅ์ฑ์ด ์๋ ๊ฒ
Result
- final answer SFT๋ง ํด๋ ์ฑ๋ฅ์ด ๊ฐ์ ๋๋ค (3.1์ ๋ง์ง๋ง ํ
SFT, Majority Voting22.3 vsFew-shot+Final-Answre..23.5).- Few-shot + final-answer rl์ 1~4 ํ ํฐ ๋งํผ์ Supervision์ ๊ฐ์ง๋ง SFT๋ Hunderes๋ก ๊ฐ๊ธฐ ๋๋ฌธ์ ๋ค๋ฅด๋ค๊ณ ๋ถ์
- ORM-superviesed reward models ~= PRM
- ์ ๊ทธ๋ฆผ์ ๋ณด๋ฉด ORM์ผ๋ก ํ์ตํ ๊ฒฐ๊ณผ๊ฐ PRM label ๊ฒฐ๊ณผ์ agreement๊ฐ ๋์ ๊ฒ์ ์ ์ ์์
- ๋ํ
SFT, majority votingtracing error 11.4 vsSFT, ORM ranking4.4๋ฅผ ๋น๊ตํ์ ๋ ORM๋ง์ผ๋ก๋ trace error๋ฅผ ๋ง์ด ์ค์ผ ์ ์๋ค๊ณ ํจ - ๋ค๋ง ์ด ๊ฒฐ๊ณผ๋ ์ด ๋๋ฉ์ธ์์๋ง ์ด๋ด์๋ ์์
- low trace error requires process-based feedback or reward model
Few-shot Final-answer RL,..๊ณผSFT, Majority Voting๋๊ฐ์ ์ฐจ์ด๋ final answer๋ ๊ฑฐ์ ๋น์ทํ์ง๋ง trace error๊ฐ ๋ง์ด ์ฐจ์ด ๋จ (19.8 vs 11.4)- ๊ฐ์ ๊ฒฝํฅ์ฑ์
Few-shot + Final Answer RL, ORM reranking12.4 vs SFT, ORM / PRM reranking 4.4 - 3.4์์๋ ์ผ์ด๋จ - ํ์ง๋ง ์ฌ๊ธฐ์
ORM-RL์ ๋ฃ์ผ๋ฉดfew-shot + ORM RL, ORM rerankingtrace error๋ 5.5๊น์ง ๋จ์ด์ง - ์ฆ process SFT๋ฅผ ํ๋๊ฐ reward model์ด ํ์ํจ
- RL ์ Few-shot ์
ํ
์์๋ ์ฑ๋ฅ์ ๋ง์ด ๊ฐ์ ํ๊ณ SFT์์๋ ์ ๋นํ ๊ฐ์ ํ๋ฐ.
- ํนํ RM decoding + final answer rl์ ๊ฒฝ์ฐ ๊ฑฐ์ ์ฑ๋ฅ ๊ฐ์ ์ด ์๊ธฐ๋ ํ๋ค.