Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Support uneven DDP inputs with pytorch model.join #3325

Open edenlightning opened 3 years ago

edenlightning commented 3 years ago

See more details: https://github.com/pytorch/pytorch/issues/38174

cc @borda @tchaton @rohitgr7 @akihironitta @awaelchli

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

carmocca commented 3 years ago

Interested in this issue! Hopefully some progress is made soon :+1:

xvr-hlt commented 3 years ago

Interested in this also :)

rohan-varma commented 3 years ago

Is there any progress on this issue? Happy to help in any way.

edenlightning commented 3 years ago

@rohan-varma that would be great!! Want to try and submit a draft PR? We can help from there.

rohan-varma commented 3 years ago

@edenlightning Sounds good, I also pinged the slack channel for any feedback/discussions.

alanhdu commented 3 years ago

We'd also be very interested in this feature. Let us know if there's anything we can do to help!

rohan-varma commented 3 years ago

The PR https://github.com/PyTorchLightning/pytorch-lightning/pull/5141 is ready for review, in case anyone wants to take a look.

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

ananthsub commented 3 years ago

I discussed this more with @rohan-varma - DDP join docs: https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.join

> This module currently does not support custom distributed collective operations in the forward pass, such as SyncBatchNorm or other custom defined collectives in the model’s forward pass.

As the LightningModule is wrapped in another module which is then wrapped with DDP, the LightningModule's training_step becomes the forward pass run by the DDP wrapped module: https://github.com/PyTorchLightning/pytorch-lightning/blob/d471fa30b3bf95cfe601014bac544754067241ca/pytorch_lightning/plugins/training_type/ddp.py#L223-L227

As a result, any collective call (such as metric syncing or all_gather) that happens during the training step will cause this not to work. Therefore, I lean towards closing this out given the caveats. @awaelchli @justusschock what do you think?
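For illustration, here is a minimal sketch (hypothetical names, not Lightning's actual classes) of why training_step ends up inside what DDP treats as its forward pass:

```python
import torch.nn as nn


class _TrainingStepWrapper(nn.Module):
    """Hypothetical stand-in for the module that Lightning wraps with DDP."""

    def __init__(self, lightning_module: nn.Module) -> None:
        super().__init__()
        self.module = lightning_module

    def forward(self, *args, **kwargs):
        # DDP's hooks fire around this call, so every collective issued inside
        # training_step (metric syncing, self.all_gather, ...) runs within what
        # DDP considers its forward pass.
        return self.module.training_step(*args, **kwargs)


# ddp_model = torch.nn.parallel.DistributedDataParallel(_TrainingStepWrapper(lit_module))
```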

awaelchli commented 3 years ago

Agreed, I also don't see how this can be supported at the moment.

JiesiZhao077 commented 2 years ago

We are also blocked by the duplicated validation data, which causes the validation results to be incorrect. Is there any workaround for this issue, like only running validation on one GPU?

carmocca commented 2 years ago

@JiesiZhao077 Running on one GPU would be the easiest and safest option. You can also set this DistributedSampler, which would also solve the issue but introduces the risk of deadlocks.

https://github.com/PyTorchLightning/pytorch-lightning/blob/45f6a3b1758f88af7fd776915539800cbc0137a9/pytorch_lightning/overrides/distributed.py#L80

We use it automatically with trainer.predict. You could try using it with trainer.validate as well.
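For reference, a rough sketch of what that could look like for validation; the import path and the val_dataset attribute are assumptions that may differ from your setup:

```python
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule
from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler


class LitModel(LightningModule):
    def val_dataloader(self):
        # Give every rank its own (possibly uneven) slice of the data instead of
        # padding with duplicated samples like the stock DistributedSampler does.
        sampler = UnrepeatedDistributedSampler(self.val_dataset, shuffle=False)
        return DataLoader(self.val_dataset, batch_size=32, sampler=sampler)
```

You would also have to stop the Trainer from swapping the sampler back (replace_sampler_ddp=False on the versions discussed here), and keep in mind that any collective during validation (sync_dist=True, torchmetrics syncs, ...) can then deadlock on the ranks that received fewer batches.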

Closing this issue as there's no clear way forward at the moment.

tmbdev commented 2 years ago

> @ananthsub commented 6 hours ago:
>
> However, the manner in which join tracks collectives can quickly run into issues with other collectives that run in the forward pass / training_step.

In PyTorch, the "with Join" construct is used as a simple wrapper around training steps. It should work in simple cases even if there are more complex cases where it doesn't work.

So, why not simply add an option to the trainer that enables wrapping the invocations of training_step with "with Join"? That should be pretty straightforward, and it would leave it up to users to determine when "with Join" is the right thing to use and when it doesn't work.
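For context, the construct from the PyTorch docs looks roughly like this self-contained toy; the gloo backend, CPU tensors, and deliberately uneven batch counts are just assumptions for the example:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms import Join
from torch.nn.parallel import DistributedDataParallel as DDP


def worker(rank, world_size=2):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(torch.nn.Linear(1, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Deliberately uneven: rank 1 gets one more batch than rank 0.
    batches = [torch.ones(1, 1) for _ in range(5 + rank)]

    # Join shadows DDP's gradient all-reduce for ranks that have run out of
    # data, so the remaining ranks do not hang in backward().
    with Join([model]):
        for batch in batches:
            loss = model(batch).sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, nprocs=2)
```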

awaelchli commented 2 years ago

> So, why not simply add an option to the trainer that enables wrapping the invocations of training_step with "with Join"?

The join here is specific to PyTorch DDP. If it were implemented, it would have to live inside the DDP plugin/strategy. For simple cases it may work, but no collective calls are allowed except the ones issued under DDP.forward()/DDP.backward(), if I understand correctly.

If we did want to do it "correctly", we would probably have to set throw_on_early_termination=True and then handle the resulting error around all custom collective calls, including the ones in torchmetrics. I don't know whether that would work, but it's probably not feasible.
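A hedged sketch of what that would mean, reusing the toy worker names (model, batches, optimizer) from the example above: every rank raises a RuntimeError as soon as any rank runs out of data, and the application has to handle it wherever a custom collective could run.

```python
from torch.distributed.algorithms import Join

try:
    with Join([model], throw_on_early_termination=True):
        for batch in batches:  # per-rank lengths may differ
            loss = model(batch).sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
except RuntimeError:
    # Raised on every rank once one of them is exhausted. Any remaining custom
    # collective (self.log(sync_dist=True), torchmetrics syncs, user all_gathers,
    # ...) would need the same treatment, which is what makes a generic solution hard.
    pass
```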

carmocca commented 2 years ago

To recap, the plan would be:

Some sources:
- https://pytorch.org/docs/stable/distributed.algorithms.join.html#torch.distributed.algorithms.Join
- https://pytorch.org/tutorials/advanced/generic_join.html

awaelchli commented 2 years ago

> Is there a benefit to doing it for validation_step and test_step? Probably not

I assume there is, if collectives are being used, for example sync_dist=True in self.log or similar. However, we don't wrap the model in DDP during val and test, so join won't be available anyway.

> When the feature is enabled, we don't automatically use the generic DistributedSampler as we wouldn't want to duplicate data to make inputs even.

https://github.com/pytorch/pytorch/issues/49180 is great! Hopefully this will clarify the drop_last argument, which has a slightly misleading/incomplete description :) We would indeed need the UnrepeatedDistributedSampler.
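To make the drop_last point concrete, here is a small runnable illustration (explicit num_replicas/rank, so no process group is needed):

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))  # 10 samples, sharded across 3 "ranks"
for drop_last in (False, True):
    per_rank = [
        list(DistributedSampler(dataset, num_replicas=3, rank=r, shuffle=False, drop_last=drop_last))
        for r in range(3)
    ]
    print(f"drop_last={drop_last}: {per_rank}")

# drop_last=False: [[0, 3, 6, 9], [1, 4, 7, 0], [2, 5, 8, 1]]  -> indices 0 and 1 duplicated
# drop_last=True:  [[0, 3, 6], [1, 4, 7], [2, 5, 8]]           -> index 9 never seen
```

Either way every rank ends up with the same number of indices, which is why the UnrepeatedDistributedSampler is needed to actually feed uneven inputs.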

otaj commented 1 year ago

Hi everyone, I'm gathering information on what is needed in order to support this properly. So far, the list looks like this:

  1. Use torch.distributed.algorithms.Join (https://pytorch.org/docs/stable/distributed.algorithms.join.html) as a context manager inside which the model is run (rough sketch after this list).
  2. Use UnrepeatedDistributedSamplerWrapper.
  3. Check all modules that can issue syncs on their own (such as SyncBatchNorm) and pass them as arguments to the Join context manager from step 1.
  4. Figure out what to do with calls to self.log(..., sync_dist=True).
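A rough, hypothetical sketch of how steps 1, 2 and 4 could fit together (fit_epoch and its arguments are placeholders, not Lightning internals, and the plain UnrepeatedDistributedSampler is used instead of the sampler wrapper for brevity; step 3 is exactly the unsolved part):

```python
from torch.distributed.algorithms import Join
from torch.utils.data import DataLoader

from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler


def fit_epoch(ddp_model, optimizer, dataset, batch_size=32):
    # 2. each rank keeps its own (possibly shorter) shard instead of padding it
    sampler = UnrepeatedDistributedSampler(dataset, shuffle=False)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    # 1. shadow DDP's gradient all-reduce once a rank has run out of batches
    with Join([ddp_model]):
        for batch in loader:
            loss = ddp_model(batch).sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # 4. any self.log(..., sync_dist=True) or torchmetrics sync in here
            #    would issue collectives Join knows nothing about -> open question
```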

Is that it? cc @awaelchli, @carmocca.

justusschock commented 1 year ago

@otaj Almost.

Additionally, all metrics from torchmetrics would have to be considered as well, since they are also capable of issuing syncs on their own. And in general, the user can run arbitrary syncing calls within each of the steps, which have to be considered too (that will be the trickiest part, I guess).

otaj commented 1 year ago

Oh, those torchmetrics are going to be fun... :sweat_smile: I think capturing user calls can be solved with yet another context manager (our own, custom one). What do you think?
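Purely as a hypothetical illustration of that idea (not an existing Lightning API), such a context manager could temporarily patch self.log to drop sync_dist while the joined section runs; arbitrary user collectives would still be the hard part:

```python
import contextlib
import functools


@contextlib.contextmanager
def no_log_sync(lightning_module):
    """Hypothetical helper: force self.log(..., sync_dist=False) while active."""
    original_log = lightning_module.log

    @functools.wraps(original_log)
    def patched_log(*args, **kwargs):
        kwargs["sync_dist"] = False  # avoid extra collectives while under Join
        return original_log(*args, **kwargs)

    lightning_module.log = patched_log
    try:
        yield
    finally:
        lightning_module.log = original_log
```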

justusschock commented 1 year ago

If we can capture user calls with that, it might work similarly for torchmetrics. So let's ignore the metrics for now; if you get a working solution for everything else, I'm sure we'll manage to integrate metrics with it :D

Borda commented 1 year ago

Let's check the option with LightningLite first 🦦

awaelchli commented 1 year ago

Here is the corresponding issue as suggested in planning: #14635

yygle commented 1 month ago

Hi, any updates on this issue?