image

paper

TL;DR

  • I read this because.. : Mentioned in the CS330 lecture. #118 reportedly also uses Perceiver, so I read this to see what the version with IO attached does differently.
  • task : image classification, language modeling, optical flow, StarCraft II, …
  • problem : Every domain / task has its own separate model. Life would be easier if a single NN could handle all of them.
  • idea : A transformer encoder-decoder structure, but built on the Perceiver structure (input modalities enter through CA, i.e. cross-attention), plus output queries.
  • input : (encoder) the input array, read through CA by a learned N x D latent array (decoder) positional embedding or task embedding
  • output : (encoder) context vector (decoder) class (for image classification), token id (for MLM), …
  • architecture : The encoder follows the Perceiver form (text, image, video, etc. enter through CA) / the decoder consists only of CA between the encoder context vector and the output queries.
  • objective : The objective function appropriate to each task
  • baseline : GLUE(BERT), Image Classification(ViT-B), Optical Flow(PWCNet, RAFT), StarCraft(Transformer), AudioSet Classification(Perceiver IO)
  • data : English Wikipedia + C4, ImageNet, JFT, …
  • result : On GLUE, better performance than BERT at the same FLOPs. On optical flow, better than the baselines on some metrics. On the remaining tasks performance is decent but not the best.
  • contribution : Tested on quite a lot of modalities. Feeding a task embedding / positional embedding into the decoder seems to be the main contribution point?! The rest does not feel especially new.
  • etc. :

Details

Architecture

image
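The encode / process / decode flow summarized above can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: `attend` is a simplified single-head attention with no learned projections, MLPs, residuals, or layer norm, every array shares one channel width, and all sizes and function names are made up for the example.

```python
import math
import random

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    """Numerically stable softmax over one list of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attend(queries, keys_values):
    """Simplified attention: each query row reads a weighted mix of keys_values rows.
    Assumption: no learned Q/K/V projections, so both arrays share one width."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    keys_t = [list(col) for col in zip(*keys_values)]
    scores = matmul(queries, keys_t)
    weights = [softmax([s * scale for s in row]) for row in scores]
    return matmul(weights, keys_values)

def perceiver_io(inputs, latents, output_queries, depth):
    latents = attend(latents, inputs)        # encode: latents cross-attend to the inputs
    for _ in range(depth):
        latents = attend(latents, latents)   # process: self-attention in latent space only
    return attend(output_queries, latents)   # decode: output queries cross-attend to latents

def random_array(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
dim = 8                                  # hypothetical shared channel width
inputs = random_array(50, dim, rng)      # 50 input elements (bytes, pixels, ...)
latents = random_array(4, dim, rng)      # 4 latents, far fewer than the inputs
queries = random_array(3, dim, rng)      # 3 output queries
outputs = perceiver_io(inputs, latents, queries, depth=2)
print(len(outputs), len(outputs[0]))     # prints: 3 8
```

The inputs are touched only once, in the first cross-attention, and the number of output rows is set by the number of queries rather than by the input length.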

Output Queries

image
  • For classification such as image classification, just a single task embedding
  • For multi-task settings, several task embeddings
  • For MLM, 2048 positional embeddings
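The three cases above differ only in how many query rows the decoder is given. A rough sketch, assuming a hypothetical query width of 16 and a hypothetical multi-task setup with 3 tasks; `make_queries` is an illustrative stand-in for a learned embedding table, not something from the paper:

```python
import random

def make_queries(num_queries, dim, seed=0):
    """Stand-in for a learned embedding table: one row per desired output."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(num_queries)]

QUERY_DIM = 16  # hypothetical query width

query_sets = {
    "classification": make_queries(1, QUERY_DIM),   # one task embedding -> one output
    "multi_task": make_queries(3, QUERY_DIM),       # one embedding per task (3 is made up)
    "mlm": make_queries(2048, QUERY_DIM),           # one positional embedding per position
}

for name, rows in query_sets.items():
    print(name, len(rows), "x", len(rows[0]))
```

Since the decoder returns one output per query row, the same decoder covers a single class prediction and a 2048-position prediction without structural changes.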

Architecture details

image

Result

  • tasks image

  • GLUE image

Here, as in the introduction, they emphasize that the model works on raw UTF-8 bytes. I am not sure that is a contribution in itself (there is prior work such as BBPE?). Working on bytes makes max_len longer, but the cost does not become $$O(n^2)$$; by construction the complexity grows linearly, and that seems to be the real contribution! The table shows the same thing: the parameter count is much larger than BERT's, yet the FLOPs are lower. On the parameter side they reduced the hidden dim and increased the depth a lot. Why is that? Compared with BERT, they say max_len was increased from 512 -> 2048 and the vocab size was reduced to 256.
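To make the linear-versus-quadratic point concrete, here is a rough count of query-key pairs scored by attention. The formulas are a simplification that ignores projections, MLPs, and channel widths, and the depths and latent count below are illustrative assumptions, not values taken from the table.

```python
def self_attention_pairs(seq_len, depth):
    """Standard transformer: every layer scores all seq_len x seq_len pairs."""
    return depth * seq_len ** 2

def perceiver_io_pairs(seq_len, num_latents, depth, num_outputs):
    """Perceiver IO style: one encode CA, `depth` latent self-attentions, one decode CA."""
    encode = seq_len * num_latents          # latents attend to the inputs
    process = depth * num_latents ** 2      # independent of the input length
    decode = num_outputs * num_latents      # output queries attend to the latents
    return encode + process + decode

NUM_LATENTS = 256   # assumed for illustration
for seq_len in (512, 2048, 8192):
    dense = self_attention_pairs(seq_len, depth=12)
    # For MLM there is one output query per input position.
    sparse = perceiver_io_pairs(seq_len, NUM_LATENTS, depth=26, num_outputs=seq_len)
    print(seq_len, dense, sparse)
```

Quadrupling `seq_len` multiplies the dense count by 16, while the Perceiver IO count only grows by the two terms that are linear in `seq_len`. That is why a much deeper latent stack can still come out cheaper at length 2048.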

  • image classification image

Compared with ViT-B/16 this does not look particularly good; for now the performance seems worse than ViT. The JFT-pretrained model scores 86.4, which is a noticeable gap from ViT-H/14's 88.6 (although the parameter count is about 1/3). It is also a bit unsatisfying that the final best result comes from the variant with a conv attached. Beyond that, the most one can say is that it improves on the previous Perceiver?

  • AudioSet Classification image

  • StarCraft II image