TL;DR
- I read this because : Is it bad to do too much SFT?
- task : reasoning model
- problem : As SFT progresses, pass@1 improves but pass@k tends to worsen
- idea : weight-ensemble the pretrained and SFT models (WiSE-FT)
- input/output : prompt -> {reasoning, answer}
- architecture : {Gemma-2-2B, Qwen-2.5-0.5B}
- objective : ce loss, GRPO loss
- baseline : SFT, temperature majority voting
- data : SFT {GSM8K, OpenThoughts-114k (cold-start SFT)} -> GRPO {30K subset of rephrased questions from MetaMath}
- evaluation : AIME24, MATH500, GSM8K / majority voting, BoN
- result : Diversity decreases as SFT progresses. The upper bound of RL performance drops as SFT goes on longer. WiSE-FT performs best, beating BoN with temperature diversification.
- contribution : various analyses
- etc. : a stated limitation is the small model scale (2B, 0.5B)
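The "weight ensemble" idea is presumably WiSE-FT-style linear interpolation between the pretrained and SFT checkpoints. A minimal sketch, assuming checkpoints are name -> array mappings; the function name and the toy numpy values are my own, not from the paper:

```python
import numpy as np

def wise_ft(theta_pre: dict, theta_sft: dict, alpha: float) -> dict:
    """Linearly interpolate every parameter between the pretrained
    and SFT checkpoints. alpha = 0 gives the pretrained model,
    alpha = 1 the fully SFT'd one."""
    return {name: (1.0 - alpha) * theta_pre[name] + alpha * theta_sft[name]
            for name in theta_pre}

# toy single-tensor "checkpoints" (illustrative only)
theta_pre = {"w": np.array([0.0, 2.0])}
theta_sft = {"w": np.array([2.0, 0.0])}
merged = wise_ft(theta_pre, theta_sft, 0.5)
```

Sweeping alpha trades pass@1 (SFT-like) against pass@k diversity (pretrained-like), which is how the interpolation can beat either endpoint.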
Details
- related work
- PRESERVING DIVERSITY IN SUPERVISED FINE-TUNING OF LARGE LANGUAGE MODELS
- Inference-Aware Fine-Tuning for Best-of-N Sampling in Large Language Models
- pass@1 vs pass@k tradeoff
    - pass@k is what matters for test-time scaling / RL scaling
    - diversity collapse: the percentage of unique answers on AIME2024 drops as SFT progresses
- RL (PPO) performance when training further from different SFT-step checkpoints
    - policy diversity breaks down without KL regularization
    - adding KL regularization does not guarantee convergence to a policy better than what the existing diversity allows ==> proof in appendix
- pass@k admits upper bounds in terms of the bias and variance of per-question pass@1, via Jensen's inequality
- SFT increases the pass@1 variance across questions (wrong questions stay wrong, right ones stay right), i.e., response diversity decreases
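The evaluation presumably uses the standard unbiased pass@k estimator (Chen et al., 2021). The toy comparison below is my own illustration of the Jensen point in the last two bullets; the per-question pass rates are hypothetical, not the paper's numbers:

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimator: 1 - C(n-c, k) / C(n, k),
    given c correct answers among n samples (Chen et al., 2021)."""
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

# Two hypothetical 2-question benchmarks with the same mean pass@1 (0.5):
# "diverse": each question solved half the time;
# "collapsed": one always solved, one never (the post-SFT pattern).
diverse   = [pass_at_k(100, 50, 8) for _ in range(2)]
collapsed = [pass_at_k(100, 100, 8), pass_at_k(100, 0, 8)]
# pass@k is concave in the per-question pass rate, so higher variance
# across questions drags average pass@k down (Jensen's inequality):
# here the diverse benchmark scores near 1.0 at k=8, the collapsed one 0.5.
```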