TL;DR
- I read this because.. : vision rl
- task : MLLM R1 replicate
- problem : MLLM R1 ํ์
- idea : ์ด์ฌํ GRPO ๋ฐ์ดํฐ ๋ชจ์์ ํ์
- input/output : {image, Q} -> reasoning, A
- architecture : InternVL2.5-7B-Instruct(r1 style), InternVL2.5-Pretrained-38B (r1-zero style)
- objective : RLOO loss
- baseline : SFT, CoT SFT(MAmmoTH-VL-8B), MPO(MMPR dataset)
- data : GeoQA-Plus, K12, CLEVR, Geometry3K, MATH, IconQA, M3CoT, DVQA, ScienceQA, ChartQA, AI2D, UniGeo, InfoVQA, GeoS, MapQA
- evaluation : MathVista, MathVerse, MathVision, Olympiad
- result : ํ๊ท ์ ์ผ๋ก ๊ฐ์ ๋ ์ํ ์ฑ๋ฅ. data scale์ด ๊ฐ์ฅ ์ ๊ฒ ๋ค์.
- contribution : ๋นจ๋ฆฌ ์ด์ฌํ ํ๋ค
- etc. :
Details
Dataset
- chart comprehension: ChartQA, DVQA, …
- General Scientific Reasoning: AI2D, ScienceQA, …
- Mathematical Reasoning: K12(proposed), GeoQA
training
reward format +
<think>...</think><answer>...</answer>parsing accuracy rewardloss advantage ๊ณ์ฐ์ RLOO
loss๋ PPO-clip loss
loss์ KL divergence term ์ถ๊ฐ ablation
- extra hparams
- rollout bs 128 / training bs 64 (8 rollout per sample)
- temperature 1
- loss term์ kl divergence ์ ์ธ
- format reward coefficient๋ instruction ์์ ์์ํ ๊ฒฝ์ฐ ์ ๋ฐ๋ฅด๊ธฐ ๋๋ฌธ์ 0.5 / pretrained weight์์ ์์ํ ๊ฒฝ์ฐ์ 1.0
key findings
- data filtering is crucial
InternVL2.5-8B-Instruct๋ก 8๋ฒ ์์ฑํ๊ฒ ํ ๋ค {0, 1} ์ ๊ฑฐ
ํ๊ณ ์ํ๊ณ ๋ ์ฐจ์ด๊ฐ ์ปธ๋ค.
- KL divergence
KL divergence๊ฐ ์์ ๋ length decrease ๊ฒฝํฅ์ด ์์๊ณ , ์ ํ๋๋ KL divergence ๋๊ณ ํค๊ณ ์ฐจ์ด๊ฐ ์์ด์ ๋๊ฒ ๋์๋ค
- Visual Aha Moment
evaluation
- K12
- ์คํ๊ต~๊ณ ๋ฑํ๊ต ์์ค์ 500๊ฐ์ fill-in-the-blank math question
- greedy decoding with a temperature 0
Result
- ํ์ต ๊ณผ์
- ์ผ๋จ MAmmoTH-VL-8B(https://mammoth-vl.github.io/ ) ๋ฅผ ์ ์ธํ๊ณ SFT๋ MPO๋ณด๋ค ์ฑ๋ฅ์ด ์ข์.
- training data scale๋ก ๋น๊ตํด๋ณด์์ ๋ SFT์๋ ํ์คํ ์ข๊ณ (SFT๋ ๋ค ํ๋ฝํจ) ๋ฐ์ดํฐ๋ฅผ ์กฐ๊ธ ๋ ์ด MPO ๋ณด๋ค math average ๊ฐ์ . ๋๋ถ๋ถ์ ๊ฐ์ ์ mathverse์ K12. olympiad๋ ๋์ง์์.
- ์ผ๋จ ๊ฐ ๋ฒค์น์ ๋ํด ํ๊ฐํด๋ณด์์ ๋ small model๊ณผ large model์ด ์ฑ๋ฅ์ด ์ฐจ์ด๊ฐ ๋ง์ด ๋๋ ๊ฒ์ Olympiad๊ฐ ๋๋ผ๋งํฑํ๋ค
- Mathvista๋ ํฐ scale์ด๋ small scale ๋๋ค mm-eureka์์ ์ข์ง ์์. ์์ธ์ง ๋ชจ๋ฅด๊ฒ ๋ค.
discussion
์๋ํ์ผ๋ ํจ๊ณผ๊ฐ ์์๋ ๊ฒ
- curriculum learning
- K12 ๋ฐ์ดํฐ์์ difficulty๋ฅผ ๋งค๊ธด ๋ค difficulty ์์ผ๋ก data sort๋ฅผ ํ๋ค.
- curriculum learning์ ํ๋ ์คํ๋ ค stable learning์ด ์๋๋ ๊ฒฝํฅ์ฑ์ด ๋ณด์๋ค.
- early~middle stage์์ ์ด๋ ค์ด ๋ฌธ์ ์ ๋ํ exploration์ ๋ชปํ๊ณ ๊ณ ์ฐฉํ๋๋๊ฒ ์๋๊ฐ? ์ถ์๋ค
- online data filtering
- difficulty {0,1}์ ์ ์ธํ๋ ๋ฐฉ์์ offline data filtering์ด๋ผ๊ณ ํ๊ณ PRIME๊ณผ ๋น์ทํ๊ฒ ํ๋ ๋ฐฉ์์ online data filtering์ด๋ผ๊ณ ํ ๋ ์ฑ๋ฅ ๊ฐ์
- online data filtering์ dynamically ๋ชจ๋ธ์ด ๊ฐ์ ๋จ์ ๋ฐ๋ผ ๋ค๋ฅธ ๋ฐ์ดํฐ๋ฅผ ๋ณผ ์ ์๋ค๊ณ ๊ธฐ๋ํ ์ ์๋ค
- ๊ทธ๋ฌ๋ online์ด ์ฑ๋ฅ ๊ฐ์ ์ด ๋ฏธ๋นํ๋๋ฐ ๊ฐ training round์์ batch size๊ฐ ๋ฌ๋ผ์ง๋ฉด์ gradient instability๊ฐ ์๊ฒจ์?๋ผ๊ณ ์๊ฐํ๋ค
- model size
- R1-zero ์๋๋ฆฌ์ค๋ฅผ small model์์ ์ฑ๊ณตํ๋ค๋ ์ฌ๋ก๋ค์ด ์์ง๋ง mm ์ํฉ์์๋ stability๊ฐ ๋์ง ์์๋ค