image

paper

TL;DR

  • I read this because: it was mentioned in a video about o1, so I thought I’d check it out.
  • task : math
  • problem : outcome-based vs process-based
  • idea: Use GSM8K to 1) learn from only the final answer, or 2) learn with human correctness annotations for each step of a human- or model-generated reasoning trace.
  • architecture : ours base-70B (like it’s a secret)
  • objective : CE (SFT) / BCE (ORM / PRM)
  • baseline : PaLM-540B, Minerva-540B, GPT-J-6B, Codex-175B, InstructGPT-175B, GPT-175B
  • data : GSM8K -> (eval) GSM8K-test, MATH
  • evaluation : final-answer error rate / trace error rate (human-annotated); MATH dataset for OOD error rate
  • result : 1) outcome- and process-based training reach similar final-answer error rates 2) an outcome-supervised RM can approximate process-based feedback 3) process-based feedback or a reward model is needed to reduce trace errors
  • contribution : Lots of analysis. Not sure how important the paper actually is

Details

image

training: overview

  • step: newline-separated (one line is one step)
  • answer: last line
  • policy network: treats “each step” as an action and “all the tokens so far” as the observation
    • trained with few-shot prompting, SFT, or RL
  • reward model -> used for reranking
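The step/answer convention above can be sketched as follows (a minimal sketch; the helper name and example trace are my own, not from the paper):

```python
def split_trace(trace: str):
    """Split a reasoning trace into steps (one line per step);
    the last non-empty line holds the final answer."""
    lines = [ln for ln in trace.strip().split("\n") if ln.strip()]
    return lines[:-1], lines[-1]

# Example (made-up trace in the GSM8K style):
trace = "18 / 2 = 9\n9 * 3 = 27\nThe answer is 27"
steps, answer = split_trace(trace)
# steps -> ["18 / 2 = 9", "9 * 3 = 27"]; answer -> "The answer is 27"
```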

SFT

  • Learn to reproduce the reasoning trace.
  • Train until the validation loss rises (approximate width of 2).

Reward model

  • ORM: similar to #209, trained with binary labels for whether the final answer is correct or incorrect
  • PRM: trained with binary labels for whether the steps so far are correct
  • Labels are human-annotated.
  • Both are trained on samples from the current policy model (temperature 1.0, K=96).
  • The ORM starts from the SFT model in the SFT setting and from the pretrained LM in the few-shot setting.
  • For the PRM, 3 samples per question from the SFT policy network are annotated.
  • Annotation focuses on questions where the SFT model’s predictions are incorrect.
  • The PRM is initialized from the ORM; since val loss fluctuates a bit, the checkpoint with the best val loss within the first 2000 steps is selected
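The ORM/PRM setup above amounts to binary classification with a BCE objective; a minimal sketch (all function names and the sampling API are placeholders of mine, not the paper's code):

```python
import math

def bce_loss(logit: float, label: float) -> float:
    """Binary cross-entropy on a correctness label -- the ORM/PRM objective."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

def orm_examples(question, sample_trace, final_answer_of, gold_answer, k=96):
    """Draw K traces from the current policy (the paper uses temperature 1.0,
    K=96) and label each 1/0 by whether its final answer matches the gold one."""
    out = []
    for _ in range(k):
        trace = sample_trace(question)
        label = float(final_answer_of(trace) == gold_answer)
        out.append((question + "\n" + trace, label))
    return out
```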

Decoding

  • Take 96 samples and apply one of the following decoding techniques
  • self-consistency (majority voting)
  • RM-weighted decoding (== verifier voting): voting weighted by RM score
  • RM-weighted decoding is slightly better than simply taking the sample with the highest RM score
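A sketch of the two voting schemes above (my own toy example; the real setup reranks 96 sampled traces by their final answer):

```python
from collections import defaultdict

def self_consistency(samples):
    """Majority vote over final answers. samples: list of (answer, rm_score)."""
    votes = defaultdict(int)
    for ans, _ in samples:
        votes[ans] += 1
    return max(votes, key=votes.get)

def rm_weighted_vote(samples):
    """Verifier voting: each sample's vote is weighted by its RM score."""
    weight = defaultdict(float)
    for ans, score in samples:
        weight[ans] += score
    return max(weight, key=weight.get)

# Toy example where the two schemes disagree:
samples = [("27", 0.2), ("27", 0.3), ("27", 0.1), ("31", 0.95)]
# self_consistency(samples) -> "27"; rm_weighted_vote(samples) -> "31"
```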

RL via Expert Iteration

image

I haven’t read the expert iteration paper, so I’m not sure, but it seems to mean repeatedly sampling traces using the RL-trained model as the policy.

  • SFT vs. few-shot settings
  • The initial policy network is either the SFT model or the base LM with a 5-shot prompt
image
  • Policy Improvement
    • final-answer RL (a.k.a. self-taught reasoner)
      • Draw K samples per question and keep those with a correct final answer
      • For SFT, keep only one trace per question (no reason given)
    • ORM-RL
      • Select the trace the ORM scores highest out of K traces
    • PRM-RL
      • Draw K (=96) candidate steps and choose the one with the highest PRM score. Stop at the final answer or after 15 steps
  • For the few-shot base, the RM is retrained each iteration; for SFT, the RM is fixed.
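One round of the final-answer-RL policy-improvement step can be sketched as below (STaR-style filtering; all helper names here are hypothetical stand-ins):

```python
def final_answer_rl_round(questions, gold, sample_trace, final_answer_of,
                          k=96, sft_setting=False):
    """Draw K traces per question, keep only those whose final answer is
    correct; in the SFT setting keep at most one trace per question, as
    the note says. Returns (question, trace) pairs for the next SFT round."""
    dataset = []
    for q in questions:
        kept = [t for t in (sample_trace(q) for _ in range(k))
                if final_answer_of(t) == gold[q]]
        if sft_setting:
            kept = kept[:1]
        dataset.extend((q, t) for t in kept)
    return dataset
```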

Data annotation

  • For stepwise labels, annotators are asked to find the first incorrect step in the model-generated trace. A step counts as incorrect if 1) the step’s expression is incorrect, or 2) there is no way to reach the correct answer without undoing this step.
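One plausible way to turn that "first incorrect step" annotation into per-step PRM targets (my interpretation, not code from the paper; in particular, labeling every step after the first error as incorrect is an assumption):

```python
def step_labels(num_steps, first_error=None):
    """Per-step binary labels from a first-incorrect-step annotation:
    steps before the first error are correct (1); the error step and
    everything after it are labeled incorrect (0), since the PRM scores
    whether the steps *so far* are correct. first_error=None means the
    whole trace is correct."""
    if first_error is None:
        return [1] * num_steps
    return [1] * first_error + [0] * (num_steps - first_error)

step_labels(5, 2)  # -> [1, 1, 0, 0, 0]
step_labels(3)     # -> [1, 1, 1]
```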

Result

image
  • SFT alone outperforms few-shot + final-answer RL on final-answer error rate (last row of 3.1: SFT, majority voting 22.3 vs. few-shot + final-answer RL 23.5).
  • The paper’s analysis: few-shot + final-answer RL gets only 1-4 tokens’ worth of supervision per sample (the final answer), while SFT supervises hundreds of tokens.
image
  • ORM-supervised reward models ≈ PRM
  • In the figure above, the results trained with the ORM show high agreement with the PRM labels
  • Also, comparing SFT + majority voting (trace error 11.4) vs. SFT + ORM reranking (4.4), the ORM alone can reduce trace errors significantly.
  • However, this result may hold only for this domain.
image
  • Low trace error requires process-based feedback or a reward model
  • Comparing `Few-shot + final-answer RL` and `SFT, majority voting`: the final-answer error is almost the same, but the trace error differs a lot (19.8 vs. 11.4).
  • The same trend holds for few-shot + final-answer RL + ORM reranking (12.4) vs. SFT + ORM/PRM reranking (4.4 / 3.4)
  • But with ORM-RL, the few-shot + ORM-RL + ORM reranking trace error also drops to 5.5
  • This means you need either process-based SFT data or a reward model
image
  • RL improves performance significantly in the few-shot setting and only moderately in the SFT setting.
  • In particular, with RM-weighted decoding, final-answer RL gives almost no additional improvement.