image

paper

TL;DR

  • I read this because... : it kept coming up in #213
  • task : test time scaling in LLM
  • problem : analysis of test-time scaling techniques.
  • architecture : PaLM-2 (340B) // Llama 2 family (for the pretraining <-> test-time comparison)
  • data : (PRM) built a new dataset from the PRM800K prompts using PaLM-2 + Monte Carlo roll-outs
  • evaluation : MATH test split (500)
  • contribution :

Details

  • thumbnail image

test-time scale up

์ค‘์š”ํ•œ ๊ฒƒ์€ ํ•œ์ •๋œ “inference cost"๋‚ด์—์„œ ๊ฐ€์žฅ ํšจ๊ณผ์ ์œผ๋กœ ์“ธ ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์ž„. ํ•ด์„œ **“test-time compute-optimal scaling strategy”**๊ฐ€ ๋“ค์–ด๊ฐ. image

Given a fixed test-time compute budget $N$, we need to find the optimal test-time hyperparameters $\theta$ for a prompt $q$. The intuition is that the best method depends on the question's difficulty. To measure difficulty, the model's pass@1 rate estimated from 2048 samples is split into 5 bins (oracle difficulty). At real inference time, however, the ground truth is unknown, so the average learned-verifier score of the final answers can be used instead (model-predicted difficulty). Because difficulty has to be estimated this way before choosing the appropriate test-time scaling method, the estimation itself is an additional cost.
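A minimal sketch of the two difficulty estimates, assuming each question comes with a list of sampled answers that are either checked against the ground truth (oracle) or scored by a verifier (model-predicted). The equal-size binning is my reading of "5 bins", and all function names are made up for illustration.

```python
from statistics import mean
from typing import Sequence


def oracle_pass_rate(is_correct: Sequence[bool]) -> float:
    """pass@1 estimate: fraction of sampled answers that match the ground truth."""
    return sum(is_correct) / len(is_correct)


def predicted_pass_rate(verifier_scores: Sequence[float]) -> float:
    """Ground-truth-free proxy: mean verifier score over the sampled final answers."""
    return mean(verifier_scores)


def bin_by_quantile(scores: Sequence[float], num_bins: int = 5) -> list[int]:
    """Assign each question a level 1..num_bins by rank.

    Level 1 = highest pass rate (easiest), level num_bins = lowest (hardest);
    the bins hold (nearly) equal numbers of questions.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    levels = [0] * len(scores)
    for rank, idx in enumerate(order):
        levels[idx] = rank * num_bins // len(scores) + 1
    return levels
```

The same `bin_by_quantile` works for both estimates; only the per-question score fed into it changes.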

scaling test-time compute with verifier

ORM๋„ ์จ๋ณด์•˜์ง€๋งŒ PRM์ด consistently outperformํ•ด์„œ PRM์„ ์ผ๋‹ค๊ณ  ํ•จ

training PRM

  • data
    • lightman et al.์ด ์ œ๊ณตํ•œ PRM800K๊ฐ€ ์žˆ๊ธด ํ•˜์ง€๋งŒ PalM 2๋ฅผ ํ•™์Šตํ•˜๋Š”๋ฐ GPT generated ์“ฐ๋Š”๊ฒŒ ineffectiveํ•˜๋‹ค๊ณ  ๊ด€์ฐฐ.
    • Math Shepherd์„ ๋”ฐ๋ผ monte carlo rollout์„ ๊ฐ€์ง€๊ณ  ๊ฐ ์Šคํ…์— ๋Œ€ํ•œ reward๋ฅผ ๊ตฌํ•˜๊ณ  ์ด๊ฑธ value๋กœ ์‚ฌ์šฉ.
    • ๋ฒ ์ด์Šค ๋ชจ๋ธ์— few-shot prompt๋ฅผ ์ฃผ์–ด์„œ ์งˆ๋ฌธ๋‹น 16๊ฐœ์˜ PRM์„ ์ƒ์„ฑ. 16๊ฐœ์˜ monte carlo rollout์„ ์‹œํ–‰ํ•˜๊ณ  parsableํ•œ answer๊ฐ€ ์•ˆ๋‚˜์˜ค๋Š”๊ฑด ์ง€์›Œ๋ฒ„๋ฆผ.
  • training
    • PRM์€ ์ด 0~1์‚ฌ์ด์˜ soft value๋ฅผ ์˜ˆ์ธกํ•˜๋Š” bce๋กœ ํ•™์Šต๋˜๋Š” binary classifier๊ฐ€ ์žˆ๋Š” ํ˜•ํƒœ.
    • val loss early stoppingํ–ˆ๋‹ค๊ณ  ์จ์žˆ์–ด์„œ ๋ช‡ ์—ํญํ•œ์ง€ ๋ชจ๋ฅด๊ฒ ์Œ
  • aggregation
    • step-wise: using the last step's score worked best
    • inter-answer: "best-of-N weighted", using the PRM as the verifier.
  • search
    • BoN weighted
    • beam search: N beams; M beam width
    • lookahead search: unlike plain beam search, each of the N beams is rolled out K steps ahead, and the PRM value at that step is used to score the beam during beam search
      • temperature = 0 for the rollouts to remove stochasticity
      • essentially MCTS with the stochastic (exploration) part removed
image
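A rough sketch of the aggregation and search pieces above, assuming hypothetical `propose`, `prm`, `is_done`, and `greedy_step` callables stand in for the policy model and the PRM (none of these names come from the paper). Following my reading of the setup, it keeps the top N/M prefixes and expands each with M proposals, scores a prefix by the PRM value at its last step, and, when `k > 0`, scores each candidate after a temperature-0 rollout of k extra steps (lookahead search).

```python
from collections import defaultdict
from typing import Callable, Optional, Sequence

Prefix = tuple[str, ...]  # a partial solution as a sequence of steps


def last_step_score(step_scores: Sequence[float]) -> float:
    """Step-wise aggregation: a solution's score is the PRM score of its last step."""
    return step_scores[-1]


def best_of_n_weighted(answers: Sequence[str], scores: Sequence[float]) -> str:
    """Inter-answer aggregation: sum scores per distinct final answer, return the top one."""
    totals: dict[str, float] = defaultdict(float)
    for ans, s in zip(answers, scores):
        totals[ans] += s
    return max(totals, key=totals.get)


def prm_beam_search(
    propose: Callable[[Prefix, int], list[str]],  # hypothetical: sample m next steps
    prm: Callable[[Prefix], float],               # hypothetical: PRM score at the last step
    is_done: Callable[[Prefix], bool],            # hypothetical: prefix ends in a final answer?
    n: int,
    m: int,
    max_steps: int,
    greedy_step: Optional[Callable[[Prefix], str]] = None,  # temperature-0 step for lookahead
    k: int = 0,
) -> list[Prefix]:
    """Beam search over solution steps scored by a PRM.

    Round 1 samples n first steps; afterwards the top n // m prefixes are kept and
    each proposes m next steps. With k > 0 and greedy_step set, a candidate is scored
    by the PRM after rolling it out k extra steps greedily (lookahead search).
    Returns finished solutions in completion order, then any unfinished beams.
    """

    def score(prefix: Prefix) -> float:
        rolled = prefix
        for _ in range(k):
            if greedy_step is None or is_done(rolled):
                break
            rolled = rolled + (greedy_step(rolled),)
        return prm(rolled)

    beams: list[Prefix] = [()]
    finished: list[Prefix] = []
    per_beam = n  # the first round draws all n samples from the empty prefix
    for _ in range(max_steps):
        candidates = [b + (s,) for b in beams for s in propose(b, per_beam)]
        candidates.sort(key=score, reverse=True)
        beams = []
        for c in candidates[: max(1, n // m)]:
            (finished if is_done(c) else beams).append(c)
        if not beams:
            break
        per_beam = m
    return finished + beams
```

The final answer would then be chosen by applying `best_of_n_weighted` to the answers of the returned solutions, scored with `last_step_score`.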
  • result image

(left) ์ž‘์€ buget์ผ ๋• beam search » BoN. ๋†’์•„์งˆ ๋• BoN์ด ์ข‹๊ธฐ๋„ lookahead ๊ฐ€ ๋‹ค๋ฅธ ๋ฐฉ๋ฒ•์— ๋น„ํ•ด ๊ฐ™์€ cost๋Œ€๋น„ ๊ทธ๋ ‡๊ฒŒ ์ข‹์ง€ ์•Š์€๋ฐ simulating cost๊ฐ€ ํฌ๊ธฐ ๋•Œ๋ฌธ์— ๊ทธ๋Ÿฐ๋“ฏ. ๊ฐ€๋ น ๊ธธ๊ฒŒ ์ƒ์„ฑํ•˜๋”๋ผ๋„ ํ•œ๋‘์Šคํ…๋งŒ์— ๋๋‚ผ ์ˆ˜ ์žˆ๋Š” ๋ฌธ์ œ๋ฅผ ๊ณ„์† ํƒํ—˜ํ•˜๋Š” ๊ฒƒ์ด ๋ฐœ๊ฒฌ๋˜๊ณค ํ–ˆ์Œ

image

(right) ๋‚œ์ด๋„๊ฐ€ ์‰ฌ์šด ๊ฒฝ์šฐ์—๋Š” BoN์ด ์ข‹์•˜๊ณ  ๋†’์€ ๊ฒฝ์šฐ์—” beam search๊ฐ€ ์ข‹์•˜์Œ. – ์ด๋Š” ์ง๊ด€๊ณผ ๋งž๋Š”๋ฐ ์–ด๋ ค์šด ๋ฌธ์ œ๋Š” fisrt place์—์„œ ์ž˜ ๋‚˜์˜ค๊ธฐ ์–ด๋ ค์›Œ์„œ search๊ฐ€ ํ•„์š”ํ•˜๊ณ  ๋‚œ์ด๋„๊ฐ€ ์‰ฌ์šด ๊ฒฝ์šฐ์—๋Š” beam-search๊ฐ€ over optimizationํ•˜๋Š” ๊ฒฝํ–ฅ์ด ์žˆ์Œ. ๊ทธ๋ฆฌ๊ณ  ๊ฐ€์žฅ ์–ด๋ ค์šด ๋ฌธ์ œ๋Š” ๋‹ค ๊ฒฐ๊ณผ๊ฐ€ ์•ˆ์ข‹์•˜๋Š”๋ฐ (test-time scaling์ด ํšจ๊ณผ๊ฐ€ ์—†์—ˆ๋‹ค๋Š” ๋œป์ผ๋“ฏ) ์ด๋Š” ์–ด๋ ค์šด๋ฌธ์ œ์— ๋Œ€ํ•ด verifier๊ฐ€ ์ •ํ™•ํ•œ ํ•ด๊ฒฐ์„ ํ•˜์ง€ ๋ชปํ•˜๊ณ  ์˜คํžˆ๋ ค beam search๋ฅผ ํ†ตํ•ด spurious features๋ฅผ ๊ฐ•ํ™”ํ•˜๋Š” ๊ผด์ด ๋˜์–ด ์„ฑ๋Šฅ์ด ๋” ์•ˆ์ข‹์•„์ง„ ๊ฒƒ์œผ๋กœ ์ถ”์ •ํ•จ. – ํ …

image Using the compute-optimal strategy gives better performance.
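In practice, "compute-optimal" amounts to a lookup: for each difficulty bin (at a given budget), pick the strategy that scored best on held-out questions, then route each test question by its predicted bin. A minimal sketch, assuming such a validation table already exists; the table layout and names are hypothetical.

```python
from typing import Mapping


def pick_strategy_per_bin(accuracy: Mapping[tuple[int, str], float]) -> dict[int, str]:
    """accuracy[(difficulty_bin, strategy)] = held-out accuracy at a fixed budget.

    Returns the highest-accuracy strategy for every bin that appears in the table.
    """
    best: dict[int, tuple[str, float]] = {}
    for (bin_id, strategy), acc in accuracy.items():
        if bin_id not in best or acc > best[bin_id][1]:
            best[bin_id] = (strategy, acc)
    return {bin_id: strategy for bin_id, (strategy, _) in best.items()}
```

At test time the question's bin would come from the model-predicted difficulty, since the oracle version needs ground truth.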

Refining the proposal distribution

Have the model generate sequentially, e.g. via sequential refinement (revisions). image

The approach is similar to recursive introspection (https://arxiv.org/abs/2407.18219 , RISE). Intuitively, revision training should be most effective when the revised correct answer is close to the preceding wrong answer, so they used character edit distance to pair wrong answers with nearby correct ones. Ideally the trajectories should be sampled on-policy as multi-turn (i.e. sequential) rollouts, but due to limited resources they sampled answers in parallel and concatenated them. The number of wrong answers in each trajectory was sampled from 0 to 4.
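A minimal sketch of the trajectory construction described above, assuming the pairing rule is "after k wrong attempts, append the correct answer closest in character edit distance to the last wrong attempt" (my reading, not confirmed by the note). `build_revision_trajectory` is a hypothetical name; the answers are assumed to be independently (parallel) sampled strings.

```python
import random
from typing import Sequence


def edit_distance(a: str, b: str) -> int:
    """Character-level Levenshtein distance (insert, delete, substitute all cost 1)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]


def build_revision_trajectory(
    incorrect: Sequence[str], correct: Sequence[str], rng: random.Random, max_wrong: int = 4
) -> list[str]:
    """k wrong attempts (k uniform in 0..max_wrong) followed by one correct answer."""
    k = rng.randint(0, min(max_wrong, len(incorrect)))
    wrong = rng.sample(list(incorrect), k)
    if wrong:
        # Pick the correct answer that is the smallest edit away from the last mistake.
        target = min(correct, key=lambda c: edit_distance(wrong[-1], c))
    else:
        target = rng.choice(list(correct))
    return wrong + [target]
```

The concatenated list is then formatted as a multi-turn context whose final turn is the correct revision.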

image

At inference, in 38% of cases the model took an answer that was already correct and revised it into a wrong one. Because of this, when several answers are produced sequentially, the final answer is selected from the whole chain with majority voting or a verifier. (left) pass@1 increases as the sequence (number of revisions) gets longer. (right) As compute grows, this outperforms parallel voting.
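Because a later revision can overwrite a correct answer, the final answer is picked from the whole chain instead of the last turn. A minimal sketch of both selection rules, with `verifier` as a hypothetical scoring callable:

```python
from collections import Counter
from typing import Callable, Sequence


def majority_answer(chain: Sequence[str]) -> str:
    """Most frequent final answer across all revisions (ties: the one seen first)."""
    return Counter(chain).most_common(1)[0][0]


def verifier_answer(chain: Sequence[str], verifier: Callable[[str], float]) -> str:
    """The revision the verifier scores highest, which need not be the last one."""
    return max(chain, key=verifier)
```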

trade-off between sequential and parallel test-time compute

์ง๊ด€์€ sequential์€ ์‰ฌ์šด ๋ฌธ์ œ์— ๋Œ€ํ•ด์„œ ๋” ์ž˜ ๋  ๊ฒƒ ๊ฐ™๊ณ  (์™œ๋ƒ๋ฉด ์ฒ˜์Œ์— ๋ฐฉ๋ฒ•์—์„œ ์ˆ˜์ •ํ•˜๋Š” ๋ฐฉ์‹์ด๋‹ˆ) parallel์€ ์–ด๋ ค์šด ๋ฌธ์ œ์—์„œ ๋‹ค์–‘ํ•œ ์‹œ๋„๋ฅผ ํ•ด๋ด์•ผํ•˜๋‹ˆ ์–ด๋ ค์šด ๋ฌธ์ œ์— ๋Œ€ํ•ด์„œ ๋” ์ž˜ ๋  ๊ฒƒ ๊ฐ™์Œ. ์ฆ‰ ์ด๋ฅผ ๋‘˜๋‹ค ์“ฐ๋Š”๊ฒŒ ๊ฐ€์žฅ ์ข‹์„ ๊ฒƒ ๊ฐ™์Œ. image

(์˜ค๋ฅธ์ชฝ) ๋‚œ์ด๋„๊ฐ€ ๋‚ฎ์„ ๋•Œ๋Š” ๊ทธ๋ƒฅ sequential๋กœ ํ•˜๋Š”๊ฒŒ ์ข‹์•˜์ง€๋งŒ ๋†’์•„์งˆ ๋•Œ๋Š” ์ ์ •ํ•œ ๋น„์œจ์ด ์žˆ์—ˆ์Œ. (parallel์ด ๋ฌด์กฐ๊ฑด ๋˜ ์ข‹์€ ๊ฑด ์•„๋‹ˆ๋„ค)

์ด๊ฒƒ๋„ optimal๊ฐ’์ด ์žˆ์Œ image

tradeoff between test-time vs pretraining compute (didn't fully understand this part)

image
  • ๋ณ„์ด ์ตœ๋Œ€ 14๋ฐฐ ๋งŽ์€ ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ ํ•™์Šต๋œ ํ”„๋ฆฌํŠธ๋ ˆ์ด๋‹ ๋ชจ๋ธ.
  • ๊ฐ€๋กœ์ถ•์€ 6 * # parm * # tokens for pretraining (==max length?? ์ดํ•ด๋ฅผ ์ž˜ ๋ชปํ•จ) / 2 * N * total # of generated in inference time
  • ๋‚œ์ด๋„๊ฐ€ ๋†’์„์ˆ˜๋ก pretraining compute๋ฅผ ๋Š˜๋ฆฌ๋Š”๊ฒŒ ์ข‹๋‹ค๋Š” ๊ฒฐ๋ก