facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License

Questions on constructing pairs with distributed batch size in SimCLR #474

Closed · DianCh closed this issue 2 years ago

DianCh commented 2 years ago

Hi! According to the original SimCLR paper, performance benefits from a large batch size, i.e., a large number of negative pairs for the NT-Xent loss:

*(Screenshot from the SimCLR paper showing the benefit of larger batch sizes.)*

However, when we run training with DDP, the batch is divided into K sub-batches and the loss on each rank is computed using only that rank's sub-batch (i.e., each rank sees only N/K samples). In this case, how do we reproduce the original setting where the contrastive distribution is over 4096 examples?
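To make the concern concrete, here is a toy sketch (my own illustration, not VISSL code) of an NT-Xent loss that only contrasts against a single rank's sub-batch, so the negatives come from the local 2·(N/K) embeddings rather than the full global batch:

```python
import torch
import torch.nn.functional as F

def local_nt_xent(z1, z2, temperature=0.1):
    """NT-Xent over one rank's sub-batch only: z1, z2 are (N/K, d) views."""
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2*N/K, d)
    sim = z @ z.t() / temperature                       # pairwise similarities
    sim.fill_diagonal_(float("-inf"))                    # mask self-similarity
    n = z1.shape[0]
    # the i-th sample of z1 is the positive of the i-th sample of z2, and vice versa
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    # negatives are only the other local samples, not the global batch
    return F.cross_entropy(sim, targets)
```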

Thank you!

iseessel commented 2 years ago

Hi @DianCh, we all_gather the embeddings from each distributed rank: https://github.com/facebookresearch/vissl/blob/main/vissl/losses/simclr_info_nce_loss.py#L139.
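Roughly, the pattern is an autograd Function that all_gathers in forward and routes each rank's gradient slice back in backward. This is only a sketch of the general pattern (assuming an initialized torch.distributed process group); see the linked file for the exact implementation:

```python
import torch
import torch.distributed as dist

class GatherLayer(torch.autograd.Function):
    """All-gather embeddings across ranks without cutting off gradients
    (sketch of the pattern; assumes dist.init_process_group was called)."""

    @staticmethod
    def forward(ctx, x):
        out = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(out, x)            # collect every rank's embeddings
        return tuple(out)

    @staticmethod
    def backward(ctx, *grads):
        all_grads = torch.stack(grads)     # one grad tensor per rank's slice
        dist.all_reduce(all_grads)         # sum contributions from every rank's loss
        return all_grads[dist.get_rank()]  # gradient for this rank's local input
```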

Our config uses a batch size of 64 per GPU, 8 GPUs per node, and 8 nodes, which gives a global batch size of 4096.
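Spelling out the arithmetic:

```python
batch_per_gpu = 64   # per-GPU (per-replica) batch size from the config
gpus_per_node = 8
num_nodes = 8
global_batch = batch_per_gpu * gpus_per_node * num_nodes
print(global_batch)  # 4096
```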

DianCh commented 2 years ago

Thank you @iseessel! It's great to know that your gather implementation doesn't cut off the gradients, which answers my main question exactly.

BoPang1996 commented 1 year ago

Hi, I do not understand why the GatherLayer function all_reduces the gradients in its backward function. In my understanding, the *grads are the same on all ranks and already carry the gradients from all the samples, so this all_reduce op will enlarge the gradient by dist.get_world_size() times.
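Concretely, this is what I mean (a toy sketch, not the VISSL code, just simulating what an all_reduce with SUM would do if my assumption held):

```python
import torch

# If the incoming *grads really were identical on every rank (my assumption),
# then summing them across ranks would scale the gradient by the world size.
world_size = 4
grad_on_each_rank = torch.ones(3)
summed = world_size * grad_on_each_rank   # what all_reduce(SUM) would produce
print(summed)                             # tensor([4., 4., 4.]) -> world_size x larger
```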