facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License

Does the MoCo implementation do shuffled Batch Norm when run on a single GPU? #576

Closed nemtiax closed 1 year ago

nemtiax commented 1 year ago

In the MoCo paper section 3.3, they say:

"We resolve this problem by shuffling BN. We train with multiple GPUs and perform BN on the samples independently for each GPU (as done in common practice). For the key encoder fk, we shuffle the sample order in the current mini-batch before distributing it among GPUs (and shuffle back after encoding); the sample order of the mini-batch for the query encoder fq is not altered. This ensures the batch statistics used to compute a query and its positive key come from two different subsets. This effectively tackles the cheating issue and allows training to benefit from BN."

In principle, one could simulate this effect on a single GPU by splitting the batch norm calculation into a small number of groups along the batch axis. Does the VISSL implementation do any special handling for the single-GPU case, or do I risk the model learning to "cheat" using the intra-batch communication offered by batch norm?
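
For what it's worth, here is a minimal sketch of the kind of split I have in mind: a `BatchNorm2d` subclass that computes statistics over a few groups along the batch axis instead of over the whole batch. The class name `SplitBatchNorm2d` and the `num_splits` argument are my own, hypothetical names, not anything from VISSL, and it assumes an affine BN layer:

```python
import torch
import torch.nn as nn


class SplitBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d that computes batch statistics over `num_splits` groups
    along the batch axis, imitating per-GPU BN statistics on a single GPU.
    Hypothetical helper for illustration, not part of VISSL."""

    def __init__(self, num_features, num_splits=2, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits

    def forward(self, x):
        n, c, h, w = x.shape
        if self.training and n % self.num_splits == 0:
            # Fold the group dimension into the channel dimension so that each
            # group gets its own statistics, then undo the reshape afterwards.
            running_mean = self.running_mean.repeat(self.num_splits)
            running_var = self.running_var.repeat(self.num_splits)
            out = nn.functional.batch_norm(
                x.view(n // self.num_splits, c * self.num_splits, h, w),
                running_mean, running_var,
                self.weight.repeat(self.num_splits),
                self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps,
            ).view(n, c, h, w)
            # Average the per-group running statistics back into the shared buffers.
            self.running_mean.copy_(running_mean.view(self.num_splits, c).mean(0))
            self.running_var.copy_(running_var.view(self.num_splits, c).mean(0))
            return out
        return super().forward(x)
```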

I see that there is a NORM config option available when setting up a model (https://vissl.readthedocs.io/en/v0.1.5/vissl_modules/models.html). It seems to me that LayerNorm should avoid the issue described in the MoCo paper, since it doesn't communicate across samples in the batch. Is there any reason I could not safely swap LayerNorm into the MoCo config? I'd be curious to know if anyone has tried this, or if there is a more straightforward approach to this problem.
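
As a quick sanity check on the LayerNorm reasoning: perturbing one sample in a batch changes only that sample's LayerNorm output, while in training mode it changes every sample's BatchNorm output. Plain PyTorch, nothing VISSL-specific:

```python
import torch
import torch.nn as nn

x = torch.randn(8, 32)  # batch of 8 samples, 32 features

# LayerNorm normalizes each sample independently over its feature dims,
# so changing one sample leaves the other samples' outputs untouched.
ln = nn.LayerNorm(32)
y1 = ln(x)
x2 = x.clone()
x2[0] += 100.0  # perturb only the first sample
y2 = ln(x2)
print(torch.allclose(y1[1:], y2[1:]))  # True: no cross-sample leakage

# BatchNorm1d computes statistics across the batch, so the same perturbation
# changes every sample's output.
bn = nn.BatchNorm1d(32).train()
z1 = bn(x)
z2 = bn(x2)
print(torch.allclose(z1[1:], z2[1:]))  # False: batch statistics leak across samples
```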

QuentinDuval commented 1 year ago

Hi @nemtiax,

Thanks a lot for your interest in VISSL and your question :)

To the best of my understanding, the issue described in the MoCo paper comes from the model's ability to "cheat" by using local batch statistics rather than global batch statistics. This means the issue only exists when you have several GPUs, each using different statistics because their batches are different. The moment you train on a single GPU, this stops being a problem, since the global statistics are the same as the local statistics.

This is why in VISSL, we only do the shuffling and un-shuffling when distributed training is enabled: https://github.com/facebookresearch/vissl/blob/main/vissl/hooks/moco_hooks.py#L164 https://github.com/facebookresearch/vissl/blob/main/vissl/hooks/moco_hooks.py#L171
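
To illustrate the pattern (this is a simplified sketch, not the actual VISSL hook code): the shuffle/unshuffle is only meaningful when torch.distributed is running with more than one process, so a single-GPU run can simply keep the identity permutation:

```python
import torch
import torch.distributed as dist


def maybe_shuffle_bn(batch: torch.Tensor):
    """Shuffle the key batch across GPUs only when distributed training is
    active. On a single GPU the local and global BN statistics coincide, so
    no shuffling is needed. Simplified illustration, not the VISSL hook."""
    if not (dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1):
        # Single-process / single-GPU training: identity permutation.
        return batch, torch.arange(batch.shape[0], device=batch.device)

    # Gather the full batch from all GPUs, pick a random permutation, broadcast
    # it from rank 0 so every rank agrees, and keep this rank's shuffled shard.
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(batch) for _ in range(world_size)]
    dist.all_gather(gathered, batch)
    full = torch.cat(gathered, dim=0)

    perm = torch.randperm(full.shape[0], device=batch.device)
    dist.broadcast(perm, src=0)
    unshuffle_perm = torch.argsort(perm)  # used to restore order after encoding

    rank = dist.get_rank()
    shard = full[perm].chunk(world_size, dim=0)[rank]
    return shard, unshuffle_perm
```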

So in short, there should be no issue with MoCo on 1 GPU, at least no problem due to this (there could be issues with the batch size being too small to get good results, as is the case for SimCLR).

As for replacing BatchNorm with LayerNorm, I have never tried it personally for MoCo, but LayerNorm has indeed risen in popularity, in particular in vision transformers. It could very well work here.

Thank you, Quentin

nemtiax commented 1 year ago

Great, that makes sense to me. As you say, the MoCo appendix suggests that the method of "cheating" through BN is to identify which sub-batch contains the target, which would not be an issue for a single GPU.

Thanks for your help!

nemtiax commented 1 year ago

Documenting what I found in case someone lands on this issue via a Google search in the future. I trained MoCo on CIFAR-10 using one GPU with both BN and LN. LN seems to get stuck in a degenerate solution for many epochs before eventually breaking out and beginning to learn. As might be expected, even after recovering and beginning to learn, the final performance of the LN model is far worse: it loses roughly 24 points of accuracy compared to the BN model (~86% -> ~62%) after fine-tuning a linear classifier on the descriptors.

Training loss, orange is BN, blue is LN: [image]

Gradients on a representative sample of model weights. Note that the LN model has very small gradients for the first 60 or so epochs:

[image]

I don't yet have an explanation for exactly what is going wrong in the LN case, but my hypothesis is that BN helps prevent the model from pushing all the images in a batch to the same descriptor, and so makes it easier to start learning.