image

paper , code

Introduction

Bootstrap* Your Own Latent (BYOL) is designed so that two networks, an online network and a target network, interact and learn from each other. The online network receives one augmented view of an image and is trained to predict the representation that the target network produces for a differently augmented view of the same image. At the same time, the target network is updated as a slow-moving average of the online network. Current SOTA methods rely on negative pairs, but BYOL achieves a new SOTA without them.

*Here "bootstrap" is not used as an ML term but in its literal sense: "to improve your situation or become more successful, without help from others or without advantages that others have".

image
  • Previous work bootstrapped pseudo-labels, cluster indices, or a handful of labels, whereas this work bootstraps the representation directly.
  • Because it does not use negative pairs, the method is robust to the choice of image augmentations.
  • Methods such as #9 learn by predicting that an image and its augmented views are the same image, but posing the prediction problem directly in representation space leads to representation collapse. To prevent this, those methods instead learn to tell augmented views of the same image apart from augmented views of other images, which has the limitation of requiring a very large number of negative samples.
  • A simple way to prevent collapse without negative samples is to use a fixed, randomly initialized network to produce the targets for our predictions. This does prevent collapse, but performance is low. Surprisingly, though, linear evaluation of a randomly initialized network alone gives 1.4% accuracy, whereas a network trained to predict the output of a fixed randomly initialized network reaches 18.8%. This experiment was the motivation for BYOL.
  • Given a representation (the target network), we can train a new online network to predict the target representation. Repeating this procedure yields representations of increasingly high quality: each resulting online network is set as the new target network for further training. In practice, the bootstrap procedure uses an exponential moving average of the online network as the target.

BYOL

image
  • The online network consists of an encoder, a projector, and a predictor, and has weights \theta.
  • The target network has the same architecture as the online network but a different set of weights \psi, and its role is to provide the targets for the online network. The parameters \psi are a moving average of the online parameters \theta.
image
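The moving-average relationship between the two sets of weights can be sketched in a few lines of plain Python. This is only an illustration: the function name `ema_update`, the decay rate `tau`, and the toy scalar weights are assumptions made for the example and do not come from the notes above.

```python
def ema_update(target_params, online_params, tau):
    """Return new target weights as a moving average of the online weights.

    Each target weight becomes tau * target + (1 - tau) * online.
    `tau` is a hypothetical name for the decay rate; values close to 1
    make the target network move slowly.
    """
    return [tau * t + (1.0 - tau) * o
            for t, o in zip(target_params, online_params)]


# Toy example: two scalar "weights" stand in for real network parameters.
online = [1.0, 2.0]   # plays the role of the online weights
target = [0.0, 0.0]   # plays the role of the target weights

# tau = 0.5 is chosen only so the numbers are easy to follow.
for _ in range(3):
    target = ema_update(target, online, tau=0.5)

print(target)
```

With each call the target weights move part of the way toward the online weights, so the target lags behind the online network instead of being trained by gradient descent.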

For each image, two augmented views \nu and \nu' are created and passed through the online network and the target network respectively. The MSE is then computed between the output of the online network's final prediction layer and the target network's projection. image

Next, the augmented views \nu and \nu' are swapped, so that \nu' is fed to the online network and \nu to the target network, and the loss is computed again. The two losses are then summed and minimized with respect to \theta only.
image
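The symmetrized loss described above can be sketched as follows. This is a minimal illustration in plain Python, not the authors' implementation: the networks are stand-in functions, all names are hypothetical, and L2-normalizing both vectors before taking the squared error is an assumption of this sketch (the notes only say MSE).

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def regression_loss(prediction, target_projection):
    """Squared error between the two vectors after L2 normalization.

    The normalization step is an assumption of this sketch.
    """
    p = l2_normalize(prediction)
    z = l2_normalize(target_projection)
    return sum((a - b) ** 2 for a, b in zip(p, z))


def byol_loss(online, target, view_a, view_b):
    """Symmetrized loss: each view passes through the online network once
    and through the target network once, and the two losses are summed.

    In a real framework the target output would be treated as a constant,
    so that only the online weights receive gradients.
    """
    return (regression_loss(online(view_a), target(view_b))
            + regression_loss(online(view_b), target(view_a)))


# Hypothetical stand-ins for the networks (encoder + projector [+ predictor]).
online_net = lambda v: [2.0 * x for x in v]
target_net = lambda v: list(v)

loss_different = byol_loss(online_net, target_net, [1.0, 0.0], [0.0, 1.0])
loss_same = byol_loss(online_net, target_net, [3.0, 4.0], [3.0, 4.0])
print(loss_different, loss_same)
```

In this toy setup the loss is large when the two outputs point in different directions and close to zero when they align, regardless of their scale.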

Implementation details

  • Image Augmentation: uses the same augmentation set as #9: random patch selection at 224 x 224, random horizontal flip, …
  • Architecture: ResNet-50 for the encoder, average pooling for the representation layer, MLP (4096 -> ReLU -> 256) for the prediction layer. No batch norm.
  • Optimization: LARS, cosine decay, …
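As a rough illustration of the cosine decay mentioned in the optimization item, the sketch below computes a learning rate that falls from a base value to zero over training. The function name, the `base_lr` value, and the absence of a warmup phase are all assumptions of this example, not settings taken from the notes.

```python
import math


def cosine_decay(step, total_steps, base_lr):
    """Learning rate following half a cosine period, from base_lr down to 0.

    Hypothetical helper: a real schedule may also include warmup.
    """
    progress = min(step, total_steps) / total_steps
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Illustrative values only.
lr_start = cosine_decay(0, 100, base_lr=0.2)
lr_mid = cosine_decay(50, 100, base_lr=0.2)
lr_end = cosine_decay(100, 100, base_lr=0.2)
print(lr_start, lr_mid, lr_end)
```

The rate changes slowly at the beginning and end of training and fastest in the middle.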

Result

  • linear evaluation in ImageNet image

  • Finetuning(=Semi-supervised training) in ImageNet image

  • Transfer to other classification task image

  • Transfer to other vision task image

Ablation

image

Compared with SimCLR, performance dropped less as the batch size was reduced and as augmentations were removed.

image

Using a moving average made a meaningful difference.

image

Having a target network made a meaningful difference.