dhkim0225 / 1day_1paper

read 1 paper everyday (only weekday)
54 stars 1 forks source link

[64] Pushing the limits of self-supervised ResNets: Can we outperform supervised learning without labels on ImageNet? (RELICv2) #92

Open dhkim0225 opened 2 years ago

dhkim0225 commented 2 years ago

paper

image

ResNet50 에 width [1x, 2x, 4x] , ResNet200 width 2x 결과. ResNet50 에서 label 없이 ImageNet을 뚫은 첫 결과.

RELIC v1

image

style과 content와 이미지-레이블의 인과 그래프에 대한 가정.

  1. data 는 style 과 content 로 나뉜다.
  2. content 만 unknown downstream task 와 관련 있다.
  3. content 와 style 은 independent 하다.

style에 대해 invariant한 representation을 self supervision. Augmentation을 style에 대한 intervention으로 사용.

결국 한 이미지에 대해 aug 한 다음에 KLD ㅋㅋ

RELIC v2

Saliency map 을 활용한 학습방식 DeepUSPS 라는 unsupervised saliency map model 을 imagenet 의 일부를 이용해서 학습시킴. ema network 인 target network, online network 두 가지가 사용됨

large crop 에 대해서 특정 확률로 (베르누이) background 를 없애버림. (saliency map 적용)

(1) target network g 에 large-crop 을 통과시킨 output (2) online network f 에 large-crop 을 통과시킨 output (3) online network f 에 small-crop 을 통과시킨 output

(1), (2) 로 loss 구하고, (1), (3) 으로 loss 구함 RELICv2 Loss 는 단순히 ContrastiveNLL 과 KLD 를 더해주는 방식이다. image

ContrastiveNLL image

pi function 은 similarity function. f, g network output 에다가 각각 다른 multi-layer head 를 붙여서 나온 녀석들 비교. g network 는 f 의 EMA 사용. image

KLD (sg == stop gradient) image

"""
f_o: online network: encoder + comparison_net
g_t: target network: encoder + comparison_net
gamma: target EMA coefficient
n_e: number of negatives
p_m: mask apply probability
"""

for x in batch: # load a batch of B samples
    # Apply saliency mask and remove background
    x_m = remove_background(x)
    for i in range(num_large_crops):
        # Select either original or background-removed
        # Image with probability p_m
        x = Bernoulli(p_m) ? x_m : x

        # Do large random crop and augment
        xl_i = aug(crop_l(x))

        ol_i = f_o(xl_i)
        tl_i = g_t(xl_i)

    for i in range(num_small_crops):
        # Do small random crop and augment
        xs_i = aug(crop_s(x))

        # Small crops only go through the online network
        os_i = f_o(xs_i)

    loss = 0
    # Compute loss between all pairs of large crops
    for i in range(num_large_crops):
        for j in range(num_large_crops):
            loss += loss_relicv2(ol_i, tl_j, n_e)

    # Compute loss between small crops and large crops
    for i in range(num_small_crops):
        for j in range(num_large_crops):
            loss += loss_relicv2(os_i, tl_j, n_e)

    scale = (num_large_crops + num_small_crops) * num_large_crops
    loss /= scale

    # Compute grads, update online and target networks
    loss.backward()
    update(f_o)
    g_t = gamma * g_t + (1 - gamma) * f_o

Results

ImageNet

image

Data fraction 이 적어도 잘 된다. image

JFT pretrain image

Transfer Learning

image

Semantic Segmentation

image