facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License
3.25k stars, 331 forks

Implement NNCLR #316

Open OlivierDehaene opened 3 years ago

OlivierDehaene commented 3 years ago

🌟 New SSL approach addition

Approach description

NNCLR: https://arxiv.org/abs/2104.14548

Self-supervised learning algorithms based on instance discrimination train encoders to be invariant to pre-defined transformations of the same instance. While most methods treat different views of the same image as positives for a contrastive loss, we are interested in using positives from other instances in the dataset. Our method, Nearest-Neighbor Contrastive Learning of visual Representations (NNCLR), samples the nearest neighbors from the dataset in the latent space, and treats them as positives. This provides more semantic variations than pre-defined transformations. We find that using the nearest-neighbor as positive in contrastive losses improves performance significantly on ImageNet classification, from 71.7% to 75.6%, outperforming previous state-of-the-art methods. On semi-supervised learning benchmarks we improve performance significantly when only 1% ImageNet labels are available, from 53.8% to 56.5%. On transfer learning benchmarks our method outperforms state-of-the-art methods (including supervised learning with ImageNet) on 8 out of 12 downstream datasets. Furthermore, we demonstrate empirically that our method is less reliant on complex data augmentations. We see a relative reduction of only 2.1% ImageNet Top-1 accuracy when we train using only random crops.

Architecture

[Image: NNCLR architecture diagram]

Pseudocode

# f: backbone encoder + projection MLP
# g: prediction MLP
# Q: queue

for x in loader: # load a minibatch x with n samples
    x1, x2 = aug(x), aug(x) # random augmentation
    z1, z2 = f(x1), f(x2) # projections, n-by-d
    p1, p2 = g(z1), g(z2) # predictions, n-by-d

    NN1 = NN(z1, Q) # top-1 NN lookup, n-by-d
    NN2 = NN(z2, Q) # top-1 NN lookup, n-by-d

    loss = L(NN1, p2)/2 + L(NN2, p1)/2
    loss.backward() # back-propagate
    update(f, g) # SGD update
    update_queue(Q, z1) # Update queue with latest projection embeddings

def L(nn, p, temperature=0.1):
    nn = normalize(nn, dim=1) # l2-normalize
    p = normalize(p, dim=1) # l2-normalize

    logits = nn @ p.T # Matrix multiplication, n-by-n
    logits /= temperature # Scale by temperature

    n = p.shape[0] # mini-batch size
    labels = range(n)

    loss = cross_entropy(logits, labels)

    return loss
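
For reference, the pseudocode above could be made runnable in PyTorch roughly as follows. This is only a minimal sketch, not VISSL code: the helper names `nearest_neighbors`, `nnclr_loss` and `update_queue` are hypothetical, cosine similarity is assumed for the nearest-neighbor lookup, the queue is assumed to be a plain FIFO tensor, and random tensors stand in for the outputs of `f` and `g`.

```python
import torch
import torch.nn.functional as F


def nearest_neighbors(z: torch.Tensor, queue: torch.Tensor) -> torch.Tensor:
    """Top-1 lookup: for each row of z, return the most cosine-similar row of queue."""
    z = F.normalize(z, dim=1)
    queue = F.normalize(queue, dim=1)
    idx = (z @ queue.T).argmax(dim=1)  # index of the nearest neighbor, shape (n,)
    return queue[idx]  # shape (n, d)


def nnclr_loss(nn: torch.Tensor, p: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Cross-entropy over the n-by-n similarity matrix; positives are on the diagonal."""
    nn = F.normalize(nn, dim=1)
    p = F.normalize(p, dim=1)
    logits = nn @ p.T / temperature  # shape (n, n)
    labels = torch.arange(p.shape[0], device=p.device)
    return F.cross_entropy(logits, labels)


@torch.no_grad()
def update_queue(queue: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """FIFO update (assumed): newest embeddings go first, the oldest are dropped."""
    return torch.cat([z.detach(), queue], dim=0)[: queue.shape[0]]


# Toy usage with random stand-ins for f(x1), f(x2), g(z1), g(z2).
torch.manual_seed(0)
n, d, queue_size = 8, 16, 32
queue = torch.randn(queue_size, d)
z1, z2 = torch.randn(n, d), torch.randn(n, d)
p1 = torch.randn(n, d, requires_grad=True)
p2 = torch.randn(n, d, requires_grad=True)

nn1 = nearest_neighbors(z1, queue)
nn2 = nearest_neighbors(z2, queue)
loss = nnclr_loss(nn1, p2) / 2 + nnclr_loss(nn2, p1) / 2
loss.backward()
queue = update_queue(queue, z1)
```

In a real training loop, `p1` and `p2` would come from the prediction MLP, so `loss.backward()` would reach the encoder through them.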

Open source status

OlivierDehaene commented 3 years ago

I am available to work on this :)

prigoyal commented 3 years ago

Hi @OlivierDehaene, this sounds amazing! Go ahead with the implementation and let us know if we can help in any way! :)

OlivierDehaene commented 3 years ago

@prigoyal, @QuentinDuval,

There is something that puzzles me with the loss.

In SimCLR, we compare view1 with view2, but we also compare each view with itself and mask the diagonal. This is done to add negative pairs. Here, since view1 and view2 are not sampled from the same data augmentation (NNCLR uses the BYOL pipeline, where view1 and view2 have different Gaussian blur and solarization probabilities), it is not clear whether we are allowed to do that. If it is not allowed, it effectively divides the batch size by 2 (EDIT: it divides the number of negative pairs by 2).

From the pseudocode above, we can infer that we should only compare view1 and view2. However, it is not clear if they omitted it to make the pseudocode more explicit or because it has an effect on performance.

TLDR: should we only compare views that are sampled from different data augmentations?
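
To make the "divides the number of negative pairs by 2" point concrete, here is a small counting sketch. It only illustrates the two layouts being compared; the tensors are random stand-ins, the variable names are hypothetical, and it says nothing about which layout the NNCLR authors actually used.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 4, 8
v1 = F.normalize(torch.randn(n, d), dim=1)  # stand-in embeddings of view1
v2 = F.normalize(torch.randn(n, d), dim=1)  # stand-in embeddings of view2

# Cross-view only (as in the pseudocode): each row holds 1 positive and n - 1 negatives.
cross_logits = v1 @ v2.T  # shape (n, n)
cross_negatives = cross_logits.shape[1] - 1

# SimCLR-style: concatenate both views, compare everything with everything,
# and mask each sample's similarity with itself.
z = torch.cat([v1, v2], dim=0)  # shape (2n, d)
self_mask = torch.eye(2 * n, dtype=torch.bool)
full_logits = (z @ z.T).masked_fill(self_mask, float("-inf"))  # shape (2n, 2n)
# Per row: 2n entries, minus the masked self entry, minus the single positive.
simclr_negatives = int((~self_mask).sum(dim=1)[0]) - 1

print(cross_negatives, simclr_negatives)  # 3 6
```

With a batch of n samples, the cross-view layout gives n - 1 negatives per anchor and the SimCLR-style layout gives 2n - 2, which is the factor of 2 mentioned above.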

QuentinDuval commented 3 years ago

This is a good question (hard to answer because we do not have the implementation yet) so I went back to the paper, and the loss is defined as:

[Screenshot: loss definition from the NNCLR paper]

They mention right after that they make it symmetric by adding the following term as well (but they also say that this has little influence on the performance):

[Screenshot: additional symmetric loss term from the NNCLR paper]

So my understanding is that we should have the same number of negative pairs as SimCLR.
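
The equations themselves are not reproduced in this thread, so the following is only one plausible reading of the symmetric variant: a sketch that assumes the extra term is the same cross-entropy taken over the other axis of the logits matrix. The function name `symmetric_nn_loss` is hypothetical and the inputs are random stand-ins.

```python
import torch
import torch.nn.functional as F


def symmetric_nn_loss(nn: torch.Tensor, p: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Average of row-wise and column-wise cross-entropy on the same logits (assumed form)."""
    nn = F.normalize(nn, dim=1)
    p = F.normalize(p, dim=1)
    logits = nn @ p.T / temperature  # shape (n, n)
    labels = torch.arange(p.shape[0], device=p.device)
    loss_rows = F.cross_entropy(logits, labels)    # each neighbor against all predictions
    loss_cols = F.cross_entropy(logits.T, labels)  # each prediction against all neighbors
    return (loss_rows + loss_cols) / 2


torch.manual_seed(0)
nn_emb = torch.randn(8, 16)                   # stand-in for NN(z1, Q)
pred = torch.randn(8, 16, requires_grad=True)  # stand-in for p2
sym_loss = symmetric_nn_loss(nn_emb, pred)
sym_loss.backward()
```

Under this reading, each positive pair is contrasted against n - 1 negatives along the rows and another n - 1 along the columns.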

HuangChiEn commented 2 years ago

Hello, may I ask a question about the implementation of NNCLR here? Although torch.index_select may cover the implementation details, I wonder whether that operation is differentiable.

In NNCLR, the encoder whose embedding is replaced by one from the support set is not updated in that step. Does that mean we should detach the embedding before feeding it into the support set's query method?

Thanks for any suggestions.
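
A small experiment may help with the differentiability question. This sketch only demonstrates PyTorch autograd behavior under stated assumptions (a queue stored without gradient history, cosine-similarity lookup, random stand-in tensors); it is not a statement about how the paper's authors or VISSL implement the support set.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d, queue_size = 4, 8, 16
queue = torch.randn(queue_size, d)         # support set, stored without gradient history
z = torch.randn(n, d, requires_grad=True)  # stand-in for the projection of view1
p = torch.randn(n, d, requires_grad=True)  # stand-in for the prediction of view2

# argmax returns integer indices, so no gradient can flow through them back to z.
idx = (F.normalize(z, dim=1) @ F.normalize(queue, dim=1).T).argmax(dim=1)

# index_select is differentiable with respect to its input values (here: queue),
# but never with respect to the integer indices.
nn = torch.index_select(queue, 0, idx)

logits = F.normalize(nn, dim=1) @ F.normalize(p, dim=1).T / 0.1
loss = F.cross_entropy(logits, torch.arange(n))
loss.backward()

print(nn.requires_grad)    # False: the selected neighbors carry no gradient history
print(z.grad is None)      # True: nothing reaches z through the lookup
print(p.grad is not None)  # True: the prediction branch is still trained
```

In this setup no explicit `detach()` is needed on the lookup path, because the index computation already cuts the graph. Detaching still matters when writing embeddings into the queue, so that the stored tensors do not keep old graphs alive.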

Obsinaan commented 1 year ago

Is the code available, please?

HuangChiEn commented 1 year ago

You can move to solo-learn, a more powerful SSL framework with a readable repository: 🔗 https://github.com/vturrisi/solo-learn

and say goodnight to vissl ~

xiachenrui commented 1 year ago

Thanks for your kind reply !