vturrisi / solo-learn

solo-learn: a library of self-supervised methods for visual representation learning powered by Pytorch Lightning
MIT License

ddp implementation of the losses #258

Closed. priyamdey closed this issue 2 years ago

priyamdey commented 2 years ago

Hi there. I am new to DDP training. Reading about the collectives provided by PyTorch DDP, I'm wondering why we are using all_gather rather than gather. all_gather collects the outputs from all processes and makes them available to every process, whereas gather only sends them to the master process. Can't we do this only on the master process (i.e., use only gather)? It would be very helpful if you could paint a rough picture of what the overall process looks like and where my understanding is wrong. Thanks.

vturrisi commented 2 years ago

Hi. When we are doing contrastive learning, we benefit from having as many negatives as possible. Because of this, we gather all the instances from all processes so that every GPU has as many negatives as possible.
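To make the difference between the two collectives concrete, here is a minimal pure-Python simulation (no real processes, no torch.distributed). The helper names `simulate_gather` and `simulate_all_gather` are made up for illustration, and each string stands in for one rank's local batch of embeddings.

```python
# Pure-Python simulation: each list entry plays the role of one rank's local data.
# simulate_gather / simulate_all_gather are hypothetical helpers, not PyTorch APIs.

def simulate_gather(per_rank, dst=0):
    """Only the destination rank ends up with every rank's data."""
    return [list(per_rank) if rank == dst else None for rank in range(len(per_rank))]


def simulate_all_gather(per_rank):
    """Every rank ends up with every rank's data."""
    return [list(per_rank) for _ in per_rank]


local = ["z_rank0", "z_rank1", "z_rank2", "z_rank3"]

after_gather = simulate_gather(local)
after_all_gather = simulate_all_gather(local)

# With gather, only rank 0 could build a loss over all embeddings.
# With all_gather, every rank can use the other ranks' embeddings as negatives.
for rank in range(len(local)):
    print(rank, after_gather[rank], after_all_gather[rank])
```

With gather, ranks 1 to 3 would have nothing beyond their own batch to contrast against, which is why all_gather is the natural fit when every GPU computes its own contrastive loss.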

priyamdey commented 2 years ago

Okay got it thanks.

priyamdey commented 2 years ago

One doubt in the simclr ddp logic. The z and indexes are first gathered here and here. Then, z is normalized and gathered again (here)?

vturrisi commented 2 years ago

@priyamdey thanks for pointing this out. Indeed, we shouldn't be gathering inside the simclr method, only inside the loss. This likely also fixes what #253 was pointing out. I created #259 to fix this.

priyamdey commented 2 years ago

One doubt regarding the compute overhead. If we gather all local z's and send them to all the processes, every process will have the complete z and will compute the complete contrastive loss. The local gradients on backward will then no longer be local; they will actually be the true gradients. But DDP gathers and averages the gradients again in loss.backward() and updates each local model's parameters with this gradient. Isn't this last step essentially redundant?

vturrisi commented 2 years ago

We just gather the negatives, so that we have a (batch * views) by (batch * views * gpus) similarity matrix. The reason for that is that there's no way to backprop gradients across GPUs. Since we average over the batch dimension, if we gathered as (batch * views * gpus) by (batch * views * gpus), we would effectively just be scaling down our gradients.
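As a rough illustration of the shapes mentioned above, the following sketch computes the size of the per-GPU similarity matrix. The function names are hypothetical, and the negative count assumes a SimCLR-style setup in which an anchor's own views are the only non-negatives.

```python
# Shape arithmetic only; function names are illustrative, not part of solo-learn.

def similarity_shape(batch, views, gpus):
    rows = batch * views          # local embeddings on this GPU (the anchors)
    cols = batch * views * gpus   # local plus gathered embeddings (the candidates)
    return rows, cols


def negatives_per_anchor(batch, views, gpus):
    # Assumption: all candidates except the anchor's own `views` copies are negatives.
    return batch * views * gpus - views


print(similarity_shape(256, 2, 4))      # (512, 2048)
print(negatives_per_anchor(256, 2, 1))  # 510 negatives without gathering
print(negatives_per_anchor(256, 2, 4))  # 2046 negatives with 4 GPUs gathered
```

The rows stay local to each GPU while the columns grow with the number of GPUs, so gathering increases the number of negatives per anchor without changing how many anchors each GPU is responsible for.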

priyamdey commented 2 years ago

I see. But in your GatherLayer class, it's mentioned in the docstring that it supports backprop for the gradients across processes?

vturrisi commented 2 years ago

If we are on some GPU, we have both the data from that GPU and the data from the other GPUs. It is possible to compute the gradients for the instances on that GPU, but not the gradients for the other instances. DDP just raises an exception and doesn't let you backprop even the gradients of the instances on that device. GatherLayer bypasses this exception.

vturrisi commented 2 years ago

I re-read what I'd written and it probably wasn't clear enough. Assume that you have x1 and x2 on GPU 1 (where x1 is from that GPU and x2 was gathered from another device). You can compute the gradients for x1 but not for x2. DDP won't even allow you to do the former, because it will check that you have instances that you can't compute the gradients for. GatherLayer basically just bypasses this check and computes the gradients for x1.
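The idea can be sketched as a custom autograd function. This is a minimal sketch, not solo-learn's actual GatherLayer: the class name `GatherWithGrad` and the helper `gather` are hypothetical. It assumes the standard `torch.distributed.all_gather` call, whose output buffers carry no autograd history, and it falls back to a single-process path when no process group is initialized.

```python
import torch
import torch.distributed as dist


class GatherWithGrad(torch.autograd.Function):
    """Hypothetical sketch: all_gather in forward, local gradient slice in backward."""

    @staticmethod
    def forward(ctx, x):
        if dist.is_available() and dist.is_initialized():
            gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered, x)
        else:
            # Single-process fallback: the "gathered" result is just the local tensor.
            gathered = [x.clone()]
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        # One incoming gradient per gathered chunk; keep only the chunk that
        # corresponds to this rank's own input (x1 in the explanation above).
        if dist.is_available() and dist.is_initialized():
            return grads[dist.get_rank()]
        return grads[0]


def gather(x):
    """Concatenate the chunks from all ranks along the batch dimension."""
    return torch.cat(GatherWithGrad.apply(x), dim=0)


# Single-process usage: gradients still reach the local tensor.
z = torch.randn(4, 8, requires_grad=True)
z_all = gather(z)
z_all.sum().backward()
print(z_all.shape, z.grad.shape)
```

In this sketch the backward pass returns only the gradient for the local chunk. A variant could additionally sum the gradient chunks across ranks with `dist.all_reduce` before slicing; which behavior is wanted depends on how the loss is averaged.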

priyamdey commented 2 years ago

I see. Thanks for clarifying. It would be quite helpful if you can point to some good ddp resource references where one can read about how to implement custom ddp logic taking into account these caveats and general ddp understanding?

vturrisi commented 2 years ago

@priyamdey I think you can start by taking a look at the DDP tutorial for pytorch. Browsing other repositories is also a good idea.