
paper, code, dataset

TL;DR

  • I read this because : vision RL
  • task : replicate R1-style training for MLLMs
  • problem : bringing R1 to MLLMs
  • idea : diligently collect data at scale and run GRPO-style rule-based RL
  • input/output : {image, Q} -> reasoning, A
  • architecture : InternVL2.5-7B-Instruct(r1 style), InternVL2.5-Pretrained-38B (r1-zero style)
  • objective : RLOO loss
  • baseline : SFT, CoT SFT(MAmmoTH-VL-8B), MPO(MMPR dataset)
  • data : GeoQA-Plus, K12, CLEVR, Geometry3K, MATH, IconQA, M3CoT, DVQA, ScienceQA, ChartQA, AI2D, UniGeo, InfoVQA, GeoS, MapQA
  • evaluation : MathVista, MathVerse, MathVision, Olympiad
  • result : improved math performance on average, while requiring the smallest training-data scale
  • contribution : a fast, diligent replication
  • etc. :

Details

Dataset


  • chart comprehension: ChartQA, DVQA, …
  • General Scientific Reasoning: AI2D, ScienceQA, …
  • Mathematical Reasoning: K12(proposed), GeoQA

training

  • reward : format reward (the response must parse as <think>...</think><answer>...</answer>) + accuracy reward

  • loss : advantage computation uses RLOO
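A minimal sketch of the two rewards and the RLOO advantage described above. The string-match answer check is an assumption for illustration; the paper's actual verifier is more involved:

```python
import re
import numpy as np

def format_reward(text: str) -> float:
    """1.0 if the response parses as <think>...</think><answer>...</answer>, else 0.0."""
    pattern = r"^\s*<think>.*?</think>\s*<answer>.*?</answer>\s*$"
    return 1.0 if re.match(pattern, text, flags=re.DOTALL) else 0.0

def accuracy_reward(text: str, gold: str) -> float:
    """1.0 if the extracted answer matches the gold answer (plain string match here)."""
    m = re.search(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    return 1.0 if m and m.group(1).strip() == gold.strip() else 0.0

def rloo_advantages(rewards: np.ndarray) -> np.ndarray:
    """RLOO: each rollout's baseline is the mean reward of the other k-1 rollouts."""
    k = len(rewards)
    baselines = (rewards.sum() - rewards) / (k - 1)
    return rewards - baselines
```

The total reward would then be `accuracy + coef * format`, with `coef` the 0.5/1.0 format coefficient from the hparams below.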


The loss is the PPO-clip loss.


Ablation: adding a KL-divergence term to the loss.

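A numpy sketch of the PPO-clip loss with the optional KL term that the ablation studies. The k3-style KL estimator is an assumption; per the hparams below, the final runs drop the term (`kl_coef=0`):

```python
import numpy as np

def ppo_clip_loss(logp, logp_old, advantages, clip_eps=0.2,
                  logp_ref=None, kl_coef=0.0):
    """PPO-clip objective over per-token log-probs; KL penalty against a
    frozen reference policy is optional (off in the final runs)."""
    ratio = np.exp(logp - logp_old)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    loss = -np.minimum(unclipped, clipped).mean()
    if logp_ref is not None and kl_coef > 0:
        # k3 estimator: exp(d) - d - 1 with d = logp_ref - logp (assumed form)
        delta = logp_ref - logp
        loss += kl_coef * (np.exp(delta) - delta - 1).mean()
    return loss
```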

  • extra hparams
    • rollout bs 128 / training bs 64 (8 rollouts per sample)
    • temperature 1
    • KL-divergence term excluded from the loss
    • format reward coefficient 0.5 when starting from the instruct model (it already follows the format well), 1.0 when starting from pretrained weights

key findings

  • data filtering is crucial : sample 8 responses per question with InternVL2.5-8B-Instruct, then remove questions whose accuracy is 0 or 1 (always wrong or always right).

The gap between filtering and not filtering was large.
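The offline filter above can be sketched as follows; `generate` and `check` are hypothetical stand-ins for sampling from InternVL2.5-8B-Instruct and verifying an answer:

```python
def offline_filter(questions, generate, check, k=8):
    """Keep only questions the sampler neither always solves nor always fails.

    generate(q, k): returns k sampled answers for question q (hypothetical).
    check(ans, q):  returns True when the answer is correct (hypothetical).
    """
    kept = []
    for q in questions:
        acc = sum(check(a, q) for a in generate(q, k)) / k
        if 0.0 < acc < 1.0:  # drop questions with accuracy in {0, 1}
            kept.append(q)
    return kept
```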

  • KL divergence

With the KL-divergence term, response length tended to decrease, and accuracy also differed with it on versus off, so it was removed.

  • Visual Aha Moment

evaluation

  • K12
    • 500 fill-in-the-blank math questions at middle- to high-school level
    • greedy decoding (temperature 0)

Result

  • training curves
  • Setting aside MAmmoTH-VL-8B (https://mammoth-vl.github.io/ ), it outperforms both SFT and MPO.
  • Comparing by training-data scale, it clearly beats SFT (SFT degrades across the board) and improves the math average over MPO, which uses somewhat more data. Most of the gains come from MathVerse and K12; Olympiad does not improve much.
  • ์ผ๋‹จ ๊ฐ ๋ฒค์น˜์— ๋Œ€ํ•ด ํ‰๊ฐ€ํ•ด๋ณด์•˜์„ ๋•Œ small model๊ณผ large model์ด ์„ฑ๋Šฅ์ด ์ฐจ์ด๊ฐ€ ๋งŽ์ด ๋‚˜๋Š” ๊ฒƒ์€ Olympiad๊ฐ€ ๋“œ๋ผ๋งˆํ‹ฑํ•˜๋‹ค
  • Mathvista๋Š” ํฐ scale์ด๋‚˜ small scale ๋‘˜๋‹ค mm-eureka์—์„œ ์ข‹์ง€ ์•Š์Œ. ์™œ์ธ์ง„ ๋ชจ๋ฅด๊ฒ ๋‹ค.

discussion

์‹œ๋„ํ–ˆ์œผ๋‚˜ ํšจ๊ณผ๊ฐ€ ์—†์—ˆ๋˜ ๊ฒƒ

  • curriculum learning
    • K12 ๋ฐ์ดํ„ฐ์—์„œ difficulty๋ฅผ ๋งค๊ธด ๋’ค difficulty ์ˆœ์œผ๋กœ data sort๋ฅผ ํ–ˆ๋‹ค.
    • Image
    • curriculum learning์„ ํ•˜๋‹ˆ ์˜คํžˆ๋ ค stable learning์ด ์•ˆ๋˜๋Š” ๊ฒฝํ–ฅ์„ฑ์ด ๋ณด์˜€๋‹ค.
    • early~middle stage์—์„œ ์–ด๋ ค์šด ๋ฌธ์ œ์— ๋Œ€ํ•œ exploration์„ ๋ชปํ•˜๊ณ  ๊ณ ์ฐฉํ™”๋˜๋Š”๊ฒƒ ์•„๋‹Œ๊ฐ€? ์‹ถ์—ˆ๋‹ค
  • online data filtering
    • Call removing the difficulty-{0,1} samples offline data filtering, and a PRIME-like approach online data filtering, and compare the performance gain.
    • The hope for online data filtering is that the model dynamically sees different data as it improves.
    • However, online filtering's gains were marginal; the guess is that the batch size changing at each training round introduced gradient instability.
  • model size
    • There are reports of the R1-zero recipe succeeding with small models, but in the multimodal setting stability was not high.
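The online-filtering round discussed above can be sketched like this; `rollout` and `check` are hypothetical stand-ins. The size of the surviving batch varies from round to round, which is the suspected source of gradient instability:

```python
def online_filter_step(batch, rollout, check, k=8):
    """One training round with online filtering: sample rollouts for the
    current batch under the *current* policy and keep only prompts with
    mixed outcomes, so the kept data tracks the improving model."""
    kept = []
    for prompt in batch:
        outcomes = [check(r, prompt) for r in rollout(prompt, k)]
        if 0 < sum(outcomes) < k:  # drop all-correct and all-wrong prompts
            kept.append((prompt, outcomes))
    return kept  # note: len(kept) changes every round
```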