TL;DR
- I read this because.. : #213 ์์ ์๊พธ ์ธ๊ธ๋์ด
- task : test time scaling in LLM
- problem : test-time scaling ๊ธฐ๋ฒ์ ๋ํ ๋ถ์.
- architecture : PaLM-2 (340B) // llama 2 family (pretraining <-> test time ๋ณผ๋)
- data : (PRM) PRM800K prompt๋ฅผ ๊ฐ์ง๊ณ PaLM-2 + monte carlo roll-out์ผ๋ก ์๋ก ๋ง๋ฆ
- evaluation : MATH test split (500)
- contribution :
Details
- thumbnail
test-time scale up
์ค์ํ ๊ฒ์ ํ์ ๋ “inference cost"๋ด์์ ๊ฐ์ฅ ํจ๊ณผ์ ์ผ๋ก ์ธ ์ ์๋ ๋ฐฉ๋ฒ์.
ํด์ **“test-time compute-optimal scaling strategy”**๊ฐ ๋ค์ด๊ฐ.
์ ํด์ง test-time compute ์์ $N$๋ด์์ test-time hyper-param $\theta$๋ฅผ prompt $q$์ ๋ํด์ ์ต์ ์ ์ฐพ์์ผ ํจ. ์ด๋ฌํ ์ต์ ์ question์ ๋์ด๋์ ๋ฐ๋ผ ๋ฐฉ๋ฒ์ด ๋ฌ๋ผ์ง๋ค๋ ์ง๊ด์ด ์์. ๊ทธ๋ ๋ค๋ฉด ์ด ๋์ด๋๋ฅผ ์ด๋ป๊ฒ ์ธก์ ํ ๊ฒ์ธ๊ฐ์ ๋ํด์๋ model์ 2048 ์ํ ์ค์ pass@1 rate๋ฅผ ๊ฐ์ง๊ณ ๋์ด๋๋ฅผ ์ธก์ ํ์ฌ 5๊ฐ์ bin์ผ๋ก ๋๋ ์ ์์ (– oracle difficulty) ๊ทธ๋ฐ๋ฐ ์ค์ ๋ก infer ์ํฉ์์๋ gt๋ฅผ ๋ชจ๋ฅด๊ธฐ ๋๋ฌธ์ final answer์ ๋ํ learned verifier์ ์ ์์ ํ๊ท ์ ๊ฐ์ง๊ณ ์คํํ ์ ์์ (–model-predicted difficulty) ์ด๋ฐ ๋ฐฉ์์ผ๋ก ๋์ด๋๋ฅผ ๋๋ ๋ค์ ์ ํฉํ test-time scaling ๋ฐฉ๋ฒ์ผ๋ก ์ธก์ ํด์ผ๋๊ธฐ ๋๋ฌธ์ ์ด๋ฌํ ๋ฐฉ๋ฒ ์์ฒด๋ ์ถ๊ฐ๋ก ๋๋ cost์
scaling test-time compute with verifier
ORM๋ ์จ๋ณด์์ง๋ง PRM์ด consistently outperformํด์ PRM์ ์ผ๋ค๊ณ ํจ
training PRM
- data
- lightman et al.์ด ์ ๊ณตํ PRM800K๊ฐ ์๊ธด ํ์ง๋ง PalM 2๋ฅผ ํ์ตํ๋๋ฐ GPT generated ์ฐ๋๊ฒ ineffectiveํ๋ค๊ณ ๊ด์ฐฐ.
- Math Shepherd์ ๋ฐ๋ผ monte carlo rollout์ ๊ฐ์ง๊ณ ๊ฐ ์คํ ์ ๋ํ reward๋ฅผ ๊ตฌํ๊ณ ์ด๊ฑธ value๋ก ์ฌ์ฉ.
- ๋ฒ ์ด์ค ๋ชจ๋ธ์ few-shot prompt๋ฅผ ์ฃผ์ด์ ์ง๋ฌธ๋น 16๊ฐ์ PRM์ ์์ฑ. 16๊ฐ์ monte carlo rollout์ ์ํํ๊ณ parsableํ answer๊ฐ ์๋์ค๋๊ฑด ์ง์๋ฒ๋ฆผ.
- training
- PRM์ ์ด 0~1์ฌ์ด์ soft value๋ฅผ ์์ธกํ๋ bce๋ก ํ์ต๋๋ binary classifier๊ฐ ์๋ ํํ.
- val loss early stoppingํ๋ค๊ณ ์จ์์ด์ ๋ช ์ํญํ์ง ๋ชจ๋ฅด๊ฒ ์
- aggregation
- step-wise: last๊ฐ ๊ฐ์ฅ ์ข์๋ค๊ณ ํจ
- intesr-answer: PRM ์ verifier๋ก ์ฌ์ฉํ “best-of-N-weighted"๋ก ์ผ๋ค๊ณ ํจ.
- search
- BoN weighted
- beam search: N beams; M beam width
- lookahed search: beam search์ ๋ฌ๋ฆฌ N๊ฐ์ beam์ ๋ํด K step ์์ ๊ฐ๋ณด๊ณ ๊ทธ step์ PRM value๋ก beam serachํ๋ ๊ฒ
- stochastic์ ์์ ๊ธฐ ์ํด temperature = 0
- MCTS์์ stochastic(exploration)ํ ๊ฑธ ๋บ ๋ฒ์ ์ด๋ผ๊ณ ํ๋ฉด ๋จ
- result
(left) ์์ buget์ผ ๋ beam search » BoN. ๋์์ง ๋ BoN์ด ์ข๊ธฐ๋ lookahead ๊ฐ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ๋นํด ๊ฐ์ cost๋๋น ๊ทธ๋ ๊ฒ ์ข์ง ์์๋ฐ simulating cost๊ฐ ํฌ๊ธฐ ๋๋ฌธ์ ๊ทธ๋ฐ๋ฏ. ๊ฐ๋ น ๊ธธ๊ฒ ์์ฑํ๋๋ผ๋ ํ๋์คํ ๋ง์ ๋๋ผ ์ ์๋ ๋ฌธ์ ๋ฅผ ๊ณ์ ํํํ๋ ๊ฒ์ด ๋ฐ๊ฒฌ๋๊ณค ํ์
(right) ๋์ด๋๊ฐ ์ฌ์ด ๊ฒฝ์ฐ์๋ BoN์ด ์ข์๊ณ ๋์ ๊ฒฝ์ฐ์ beam search๊ฐ ์ข์์. – ์ด๋ ์ง๊ด๊ณผ ๋ง๋๋ฐ ์ด๋ ค์ด ๋ฌธ์ ๋ fisrt place์์ ์ ๋์ค๊ธฐ ์ด๋ ค์์ search๊ฐ ํ์ํ๊ณ ๋์ด๋๊ฐ ์ฌ์ด ๊ฒฝ์ฐ์๋ beam-search๊ฐ over optimizationํ๋ ๊ฒฝํฅ์ด ์์. ๊ทธ๋ฆฌ๊ณ ๊ฐ์ฅ ์ด๋ ค์ด ๋ฌธ์ ๋ ๋ค ๊ฒฐ๊ณผ๊ฐ ์์ข์๋๋ฐ (test-time scaling์ด ํจ๊ณผ๊ฐ ์์๋ค๋ ๋ป์ผ๋ฏ) ์ด๋ ์ด๋ ค์ด๋ฌธ์ ์ ๋ํด verifier๊ฐ ์ ํํ ํด๊ฒฐ์ ํ์ง ๋ชปํ๊ณ ์คํ๋ ค beam search๋ฅผ ํตํด spurious features๋ฅผ ๊ฐํํ๋ ๊ผด์ด ๋์ด ์ฑ๋ฅ์ด ๋ ์์ข์์ง ๊ฒ์ผ๋ก ์ถ์ ํจ. – ํ …
Refining the proposal distribution
sequential refinement ๋ฑ ์ฒ๋ผ sequentialํ๊ฒ ์์ฑํ๊ฒ ํ๋ ๊ฒ
rescursive introspection (https://arxiv.org/abs/2407.18219 , RiSE)์ ๋น์ทํ ์ ๊ทผ๋ฒ์ผ๋ก ํ๋, ์ง๊ด์ “์์ ํ ์ ๋ต"์ด “ํ๋ฆฐ ์ค๋ต"๊ฐ ๊ฐ๊น์ธ ๋, ์ด๋ฐ refinement ํ์ต์ ํจ๊ณผ์ ์ผ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ ๊ทธ๋ฐ chr edit distance๋ก ์ค๋ต์ ๊ณจ๋ผ๋ด๋ ๊ณจ๋ผ๋ด๋ ์์ ์ ํ๊ณ , ์์์ด ๋ถ์กฑํด์ ์๋๋ on-policy multi-turn(=sequential)ํ๊ฒ ๋ฝ์์ผํ๋๋ฐ ๊ทธ๋ฅ ๋ณ๋ ฌ๋ก ๋ฝ์ ๋ค์ ์ด์ด๋ถ์ด๋ ์์ผ๋ก ํ์. ์ด๋ ์ค๋ต ๊ฐ์๋ 0~4๊ฐ ์ฌ์ด์์ ํ๋ ์ํ๋งํด์ ์ฌ์ฉํ์.
inferenceํ ๋ ๋ง๋ ์ ๋ต์ ๋ด๋ฑ์ด๋ ๋ ๊ณ ์น๊ณ ์ค๋ต์ผ๋ก ํ๋ ๊ฒฝ์ฐ๊ฐ 38% ์์์. ์ด๋ฐ ํ์ ๋๋ฌธ์ sequentialํ๊ฒ ์ ๋ต์ ์ฌ๋ฌ๊ฐ ๋ฝ์ผ๋ฉด ์ด๊ฑธ๊ฐ์ง๊ณ majority voting or verifier (์ผ์ชฝ) sequence๊ธธ์ด๊ฐ ๊ธธ์ด์ง ์๋ก pass@1์ด ๋์ด๋๋ ํ์ (์ค๋ฅธ์ชฝ) parallel voting ๋ณด๋ค compute ์์์ด ๋์ด๋ ๋ ์ฑ๋ฅ์ด ๋ ์ข์
trade off sequential or parallel test-time compute
์ง๊ด์ sequential์ ์ฌ์ด ๋ฌธ์ ์ ๋ํด์ ๋ ์ ๋ ๊ฒ ๊ฐ๊ณ (์๋๋ฉด ์ฒ์์ ๋ฐฉ๋ฒ์์ ์์ ํ๋ ๋ฐฉ์์ด๋) parallel์ ์ด๋ ค์ด ๋ฌธ์ ์์ ๋ค์ํ ์๋๋ฅผ ํด๋ด์ผํ๋ ์ด๋ ค์ด ๋ฌธ์ ์ ๋ํด์ ๋ ์ ๋ ๊ฒ ๊ฐ์. ์ฆ ์ด๋ฅผ ๋๋ค ์ฐ๋๊ฒ ๊ฐ์ฅ ์ข์ ๊ฒ ๊ฐ์.
(์ค๋ฅธ์ชฝ) ๋์ด๋๊ฐ ๋ฎ์ ๋๋ ๊ทธ๋ฅ sequential๋ก ํ๋๊ฒ ์ข์์ง๋ง ๋์์ง ๋๋ ์ ์ ํ ๋น์จ์ด ์์์. (parallel์ด ๋ฌด์กฐ๊ฑด ๋ ์ข์ ๊ฑด ์๋๋ค)
์ด๊ฒ๋ optimal๊ฐ์ด ์์
tradeoff betweentes-time vs pretraining (์ ์ดํด ๋ชปํจ)
- ๋ณ์ด ์ต๋ 14๋ฐฐ ๋ง์ ํ๋ผ๋ฏธํฐ๋ก ํ์ต๋ ํ๋ฆฌํธ๋ ์ด๋ ๋ชจ๋ธ.
- ๊ฐ๋ก์ถ์ 6 * # parm * # tokens for pretraining (==max length?? ์ดํด๋ฅผ ์ ๋ชปํจ) / 2 * N * total # of generated in inference time
- ๋์ด๋๊ฐ ๋์์๋ก pretraining compute๋ฅผ ๋๋ฆฌ๋๊ฒ ์ข๋ค๋ ๊ฒฐ๋ก