pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
https://pytorch-ignite.ai
BSD 3-Clause "New" or "Revised" License

how to use all_gather in training loop? #2504

Open kkarrancsu opened 2 years ago

kkarrancsu commented 2 years ago

I have defined my train_step exactly as in the cifar10 example. Is it possible to gather all of the predictions before computing the loss? I haven't seen this pattern in the ignite examples (maybe I'm missing it?), but for my application it is better to compute the loss after aggregating the forward-pass outputs and the targets from all GPUs. This only matters when using DistributedDataParallel, since DataParallel automatically aggregates the outputs.

I see the idist.all_gather() function, but am unclear how to use it in a training loop.

sdesrozis commented 2 years ago

@kkarrancsu Thanks for your question.

In general, idist.all_gather() can be used as long as the call is made collectively by all the processes. Therefore, you can use this method to gather the predictions in your training loop.

I can provide an example asap and maybe update the doc accordingly.

However, I'm not completely sure what you are asking. If you want to compute predictions with DDP, gather them, and backpropagate from a single process, it won't work. You can check the internal design of DDP here: https://pytorch.org/docs/stable/notes/ddp.html#internal-design

kkarrancsu commented 2 years ago

@sdesrozis Thanks for your quick reply! Sorry if my initial question was unclear. As an example:

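import torch.nn as nn

m = model()  # placeholder: any nn.Module
m_dp = nn.DataParallel(m)  # single process, replicas on every visible GPU
m_ddp = nn.parallel.DistributedDataParallel(m)  # one process per GPU, process group already initialized

x = input  # placeholder batch [batch_size, ...]; under DDP each process loads its own [batch_size/ngpu, ...] shard
y_dp = m_dp(x)  # [batch_size, ...], outputs gathered onto one device
y_ddp = m_ddp(x)  # [batch_size/ngpu, ...], only this process's shard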

I'd like to gather the y_ddp from all GPUs before computing a loss. I hope that makes the question clearer.

sdesrozis commented 2 years ago

Thanks for the clarification. Would you like to use the loss as a metric? Or do you want to call loss.backward()?

kkarrancsu commented 2 years ago

I'd like to call loss.backward()

sdesrozis commented 2 years ago

Ok, so I think it won't work even if you gather the predictions. The gathering operation is not an autodiff function, so it cuts the computation graph. The forward pass also creates some internal states that won't be gathered.

That said, I'm pretty sure this has been answered on the PyTorch forum. Maybe I'm wrong, though, and I would be interested in some discussion of this topic.

EDIT: see here https://amsword.medium.com/gradient-backpropagation-with-torch-distributed-all-gather-9f3941a381f8
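To see the graph cut concretely, together with a workaround often used in contrastive-learning code, here is a sketch with plain `torch.distributed`. It assumes a process group is already initialized, and `gather_keep_local_grad` is a hypothetical helper name. `dist.all_gather` writes into ordinary tensors with no `grad_fn`, so the trick is to put the local, graph-connected tensor back into its own slot before concatenating. Gradients then reach this process's model only through its own samples; the gradient of the loss with respect to other ranks' samples is not sent back to them, which may or may not be acceptable depending on how the loss is split across ranks.

```python
import torch
import torch.distributed as dist


def gather_keep_local_grad(x: torch.Tensor) -> torch.Tensor:
    """Concatenate x from all ranks along dim 0 (hypothetical helper).

    Assumes an initialized process group and the same x.shape on every rank.
    """
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(gathered, x.contiguous())
    # None of the gathered tensors is connected to the autograd graph here.
    # Replace this rank's slot with x itself so backward can reach the local model.
    gathered[dist.get_rank()] = x
    return torch.cat(gathered, dim=0)
```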

kkarrancsu commented 2 years ago

@sdesrozis Thanks - I will investigate based on your link and report back.

sdesrozis commented 2 years ago

Good! Although I'm doubtful about the link. I'm interested in your feedback.

vfdev-5 commented 2 years ago

@kkarrancsu can you provide a bit more detail on what exactly you would like to do? In DDP, the data is distributed across N processes and the model is replicated in each of them. In the forward pass, each process obtains predictions y_preds = m_ddp(x) on its own data chunk. Calling loss.backward() on the local loss computes gradients, which the PyTorch DDP wrapper averages across processes (via all-reduce) during backward, so every replica applies the same update.
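The standard DDP flow described above, as a minimal sketch. `ddp_train_step` is a hypothetical helper; it assumes an initialized process group and that each rank already holds its own data chunk `x_local`, `y_local` (for example from a `torch.utils.data.distributed.DistributedSampler`).

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def ddp_train_step(ddp_model: DDP, optimizer, criterion, x_local, y_local) -> float:
    """One step of the usual DDP pattern: each rank works on its own data chunk."""
    optimizer.zero_grad()
    y_pred = ddp_model(x_local)        # predictions for this rank's chunk only
    loss = criterion(y_pred, y_local)  # local loss, no cross-rank communication yet
    loss.backward()                    # DDP all-reduces and averages gradients here
    optimizer.step()                   # every rank applies the same averaged gradients
    return loss.item()
```

With this pattern, a loss such as SupCon only ever sees the samples of the local chunk, which is the limitation discussed in the rest of the thread.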

As for distributed autograd, you can check as well : https://pytorch.org/docs/stable/rpc.html#distributed-autograd-framework

kkarrancsu commented 2 years ago

Hi @vfdev-5, sure.

We are using the Supervised Contrastive loss to train an embedding. In Eq. 2 of the paper, we see that the loss depends on the number of samples used to compute it (positive and negative).
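For reference, Eq. 2 (the L_out variant) of the Supervised Contrastive Learning paper by Khosla et al., as we read it: I indexes the samples in the batch, A(i) = I \ {i} is every other sample, P(i) is the subset of A(i) sharing the label of anchor i, z are the L2-normalized embeddings and τ is a temperature. The denominator sums over all of A(i), so with a per-GPU loss each anchor is only contrasted against batch/ngpu - 1 samples instead of batch - 1.

```latex
\mathcal{L}^{sup}_{out} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)}
  \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}
```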

My colleague suggested that it is better to compute the loss over all examples (the entire batch) rather than over batch/ngpu samples, which is what happens with DDP when the loss is computed locally on each GPU. The reason is that the denominator of SupConLoss sums over all other samples in the batch, including the negatives, so by first aggregating the samples across GPUs, each anchor is contrasted against more negatives and you would get a more accurate loss.
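A minimal single-device sketch of the SupCon loss from Eq. 2 of the paper. The function name and the `temperature` default are our assumptions, not from the thread; it averages over anchors instead of summing and skips anchors that have no positive. Under DDP, the idea would be to call it on embeddings and labels gathered from all ranks, so that the denominator for each anchor covers the global batch rather than batch/ngpu samples.

```python
import torch
import torch.nn.functional as F


def supcon_loss(features: torch.Tensor, labels: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """SupCon "out" loss, averaged over anchors (sketch).

    features: [N, D] embeddings (N may be the gathered, global batch).
    labels:   [N] integer class labels.
    """
    z = F.normalize(features, dim=1)
    n = z.shape[0]
    logits = z @ z.T / temperature  # [N, N] pairwise similarities
    self_mask = torch.eye(n, dtype=torch.bool, device=z.device)
    # A(i) excludes the anchor itself, so drop the diagonal from the softmax.
    logits = logits.masked_fill(self_mask, float("-inf"))
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
    # P(i): other samples with the same label as anchor i.
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask
    num_pos = pos_mask.sum(dim=1)
    # Sum log-probabilities over positives only (the diagonal -inf is zeroed out here).
    pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
    valid = num_pos > 0  # anchors without any positive are skipped
    return -(pos_log_prob[valid] / num_pos[valid]).mean()
```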

sdesrozis commented 2 years ago

Ok, I understand. You should have a look at a distributed implementation of SimCLR. See for instance

https://github.com/Spijkervet/SimCLR/blob/cd85c4366d2e6ac1b0a16798b76ac0a2c8a94e58/simclr/modules/nt_xent.py#L7

This might give you some inspiration.

lxysl commented 10 months ago

This code is not entirely correct. Please check this issue: https://github.com/Spijkervet/SimCLR/issues/30 and my PR: https://github.com/Spijkervet/SimCLR/pull/46.
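For reference, a sketch of a gather whose backward does sum gradients across ranks. This is our own illustration of the general idea, not the code from the linked issue or PR, and the class name `GatherWithGrad` is hypothetical. It assumes an initialized process group and the same local batch size on every rank. If each rank computes the loss for its local anchors against the gathered batch, the all-reduce in backward gives every rank the gradient of the summed losses with respect to its own samples, and DDP's gradient averaging then scales the result by 1/world_size. Recent PyTorch versions also ship a differentiable `torch.distributed.nn.functional.all_gather`; check its backward semantics for the version you use.

```python
import torch
import torch.distributed as dist


class GatherWithGrad(torch.autograd.Function):
    """all_gather along dim 0 whose backward sums gradients from every rank."""

    @staticmethod
    def forward(ctx, x):
        ctx.rank = dist.get_rank()
        ctx.local_size = x.shape[0]
        gathered = [torch.empty_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, x.contiguous())
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank holds d(its own loss)/d(all gathered rows); summing over ranks
        # gives d(sum of all losses)/d(rows), of which this rank keeps its own slice.
        grad = grad_output.contiguous().clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        start = ctx.rank * ctx.local_size
        return grad[start:start + ctx.local_size]
```

Usage would look like `all_feats = GatherWithGrad.apply(local_feats)`, with the labels gathered by a regular, non-differentiable all_gather.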