TL;DR
- I read this because… : it was mentioned elsewhere. I was wondering how the reasoning gets divided into steps.
- task : LLM reasoning
- problem : DPO is widely used, but its performance gains on long-chain reasoning are limited
- idea : apply the DPO loss at the step level, maximizing the win/lose margin at the first incorrect step rather than over the whole long reasoning sequence
- architecture : Qwen2, Qwen1.5, Meta-Llama-3-70B, deepseek-math-7b-base
- objective : step-DPO (proposed)
- baseline : SFT, DPO
- data : 374K samples (proposed), with the first incorrect step recorded; plus AQuA
- evaluation : MATH, GSM8K, AIME, Odyssey-MATH
- result : outperforms DPO. Reported to beat GPT-4-1106, Claude-3-Opus, and Gemini-1.5-Pro.
- contribution : public data release. There seem to be many similar works, so I'm not sure this is the first.
- etc. :
Details
Performance
Motivation
The drawback of SFT, per this paper, is that it increases the likelihood of undesirable outputs along with desirable ones, making the model prone to hallucination. RLHF addresses this by providing supervision on undesirable outputs, but DPO is said to be ineffective for long sequence outputs due to its lack of fine-grained process supervision.
Step-DPO
Maximize the win – lose margin for the wrong step, not the entire sequence.
- $s_i$ : i-th reasoning step
- $x$ : prompt
- $k$ : first incorrect step
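With these symbols, the objective as I understand it is the standard DPO loss applied to the single step $s_k$ conditioned on the correct prefix $s_{1 \sim k-1}$ (my reconstruction from the paper's description):

$$\mathcal{L}(\theta) = -\mathbb{E}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(s_{win} \mid x, s_{1 \sim k-1})}{\pi_{ref}(s_{win} \mid x, s_{1 \sim k-1})} - \beta \log \frac{\pi_\theta(s_{lose} \mid x, s_{1 \sim k-1})}{\pi_{ref}(s_{lose} \mid x, s_{1 \sim k-1})}\right)\right]$$

A minimal numeric sketch of this loss for one preference pair; the function name and the $\beta$ value are mine, not the paper's:

```python
import math

def step_dpo_loss(logp_theta_win, logp_ref_win,
                  logp_theta_lose, logp_ref_lose, beta=0.5):
    """Step-DPO loss for one (s_win, s_lose) pair at the first wrong step.

    Each argument is the summed token log-probability of that step
    conditioned on the prompt x and the correct prefix s_{1~k-1}.
    """
    margin = beta * ((logp_theta_win - logp_ref_win)
                     - (logp_theta_lose - logp_ref_lose))
    # -log(sigmoid(margin)) computed stably as softplus(-margin)
    z = -margin
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))
```

The loss shrinks as the policy raises $s_{win}$ and lowers $s_{lose}$ relative to the reference, at the step level only; tokens outside step $k$ never enter the margin.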
In-distribution data construction
The goal is to construct tuples of the form (prompt $x$, correct prefix $s_{1 \sim k-1}$, $s_{win}$, $s_{lose}$).
Pipeline
Error collection: gather problems $x$ with ground-truth answers $\hat{y}$. Run the reference model $\pi_{ref}$ with a step-wise CoT prefix, divide the output by step, and keep the samples whose final answer $y$ differs from the ground truth.
Step localization: find the first incorrect step $k$ in the reasoning $y = s_1, s_2, \dots, s_n$ (manually or via GPT-4), and take the erroneous step $s_k$ as $s_{lose}$.
Rectification: given the correct reasoning prefix $s_{1 \sim k-1}$, sample the reference model multiple times to obtain candidate continuations. Among these, one whose final answer matches the ground truth is selected as $s_{win}$. Even with a correct final answer the process can be incorrect, so candidates are further verified manually or with GPT-4 (details not shown).
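The three stages can be sketched roughly as follows. This is my own sketch, not the paper's code: the `"\n\n"` step delimiter and the `sample_fn` / `final_answer_fn` / `locate_first_error_fn` interfaces are hypothetical stand-ins for the actual sampler and GPT-4/manual tooling.

```python
def split_steps(completion):
    """Divide a step-wise CoT completion into reasoning steps
    (assuming steps are separated by blank lines)."""
    return [s for s in completion.split("\n\n") if s.strip()]

def build_pair(problem, gt_answer, sample_fn, final_answer_fn,
               locate_first_error_fn, n_retries=8):
    """Return (prompt, prefix, s_win, s_lose) for one problem, or None.

    sample_fn(prompt, prefix)          -> a completion from pi_ref
    final_answer_fn(steps)             -> the extracted final answer
    locate_first_error_fn(prompt, steps) -> index k of the first wrong step
    """
    # 1) Error collection: keep only completions whose answer is wrong.
    steps = split_steps(sample_fn(problem, prefix=""))
    if final_answer_fn(steps) == gt_answer:
        return None
    # 2) Step localization: the first incorrect step k gives s_lose.
    k = locate_first_error_fn(problem, steps)
    prefix, s_lose = steps[:k], steps[k]
    # 3) Rectification: resample from the correct prefix until the
    #    final answer matches the ground truth; that step is s_win.
    for _ in range(n_retries):
        cand = split_steps(sample_fn(problem, prefix="\n\n".join(prefix)))
        if cand and final_answer_fn(prefix + cand) == gt_answer:
            return problem, prefix, cand[0], s_lose
    return None
```

Problems where rectification never recovers a matching answer are simply dropped, which is consistent with only a subset of the collected data ending up as Step-DPO pairs.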
Result
- Collected 374K in total, of which 299K was SFT data and 75K was Step-DPO
- SFT is trained for 3 or 2 epochs
- Step-DPO is trained for 8 or 4 epochs
- The AQuA dataset is used in addition to the SFT dataset
Ablation
DPO vs Step-DPO
in-distribution vs out-distribution
It matters that the preference data is produced by inference from the very model being trained (in-distribution), rather than taken from another model.