Image

paper , dataset

TL;DR

  • I read this because.. : 역시 많이 언급되어. PRM을 학습하기 위한 대표적인 방법 중 하나인듯.
  • task : math solving
  • problem : Process Reward Model 학습하고 싶은데 Human annotated 너무 비싸다
  • idea : MCTS를 사용하여 특정 step의 value를 구하고 그걸 PRM의 label로 사용하여 학습 – step-level PPO를 학습하자
  • architecture : LLaMA2-7B/13B/70B, LLemma-7B/34B, Mistral-7B, Deepseek-67B
  • objective : (PRM) bce loss (RL) PPO loss
  • baseline : (train/infer) ORM, Self-consistency, Self-consistency + ORM (data) rule-based, BART NLI
  • data : 170K solution for GSM8K / 270K for MATH
  • evaluation : GSM8K, MATH accuracy
  • result : 좋은 성능
  • contribution : 트위터를 보면 OAI 이후 첫 PRM paper라는듯? -> 이후 이걸 개선한게 OmeagPRM인듯?

Details

  • thumbnail

Image

  • PRM loss

Image

  • automatic process annotation

Image

저 value estimation을 MCTS로 했다고 생각하면 됨! 각 step별로 다 rollout한다고 생각하면 경우의 수가 너무 많아지니 이를 최적화한게 MCTS (https://gusals1620.tistory.com/3 )

Image

결론적으론 hard를 썼는데 hparam을 모델 별로 찾지 않아도 된다서라고? (mse로 해도 되는거 아닌감 ㅎ)

  • parameter setting

    • generator 와 completer는 metamath에 대해 3 epoch씩 학습 한 것
    • ORM / PRM 학습데이터를 생성하기 위해서 GSM8K와 MATH training data를 학습 -> 이후 문제당 15개의 solution을 생성
    • completer는 Llemma-7B를 사용하여 decoded number N=8로 생성 (completer와 generator는 어떻게 다른가.. generator는 solution을 만드는거고 completer는 rollout을 하는 주체인건가? 이 두 모델이 다를수가 있나?)
    • verification을 위해서는 LLaMA-2 70B와 Llemma-34B 사용
    • PPO학습의 policy 모델은 Llama2-7B와 Mistral-7B
    • 모델을 왜 이렇게 다양하게 쓴건지 잘 모르겠음
  • result

Image

256개 sample 중 verification 방법론 중 가장 좋음.

Image

다른 학습 방법론(ORM + PPO / RFT)와 비교했을 때 성능이 좋음

Image

Image

Image

Image

Image

  • (a)(b)를 보면 math-shepherd가 verifier / ORM보다 더 좋은 성능, model 둘다 커지면 성능도 좋아짐
  • (c) self-consistency와 비교했을 때, reward model이 generator 모델보다 너무 작으면 solution per problem이 커질수록 성능이 안좋아짐 – reward model도 generator만큼 좋은걸 써야 함
  • (d) verifier가 더 클 때 (a) 보다 훨씬 좋은 성능. SC와의 차이가 많이 커짐