edenlightning opened this issue 3 years ago
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
Interested in this issue! Hopefully some progress is done soon :+1:
Interested in this also :)
Is there any progress on this issue? Happy to help in any way.
@rohan-varma that would be great!! Want to try and submit a draft PR? And we can help from there?
@edenlightning Sounds good, I also pinged the slack channel for any feedback/discussions.
We'd also be very interested in this feature. Let us know if there's anything I can do to help!
The PR https://github.com/PyTorchLightning/pytorch-lightning/pull/5141 is ready for review, in case anyone wants to take a look.
I discussed this more with @rohan-varma. From the DDP join docs (https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.join):

> This module currently does not support custom distributed collective operations in the forward pass, such as SyncBatchNorm or other custom defined collectives in the model's forward pass.
As the LightningModule is wrapped in another module which is then wrapped with DDP, the LightningModule's `training_step` becomes the forward pass run by the DDP-wrapped module: https://github.com/PyTorchLightning/pytorch-lightning/blob/d471fa30b3bf95cfe601014bac544754067241ca/pytorch_lightning/plugins/training_type/ddp.py#L223-L227

As a result, any collective call (such as metric syncing or `all_gather`) that happens during the training step will cause this to not work. Therefore I lean towards closing this out given the caveats. @awaelchli @justusschock what do you think?
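For context, a minimal sketch of the raw PyTorch feature being discussed, DDP's `join()` context manager, outside of Lightning. This is illustrative only: it runs a single process on CPU with the gloo backend (so there is no real unevenness), and the model/optimizer are placeholders.

```python
# Sketch of DDP's join() context manager (plain PyTorch, not Lightning's API).
# Assumptions: gloo backend, single process, CPU only, for illustration.
import os
import tempfile

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# File-store rendezvous avoids having to pick a free TCP port.
store_file = os.path.join(tempfile.mkdtemp(), "store")
dist.init_process_group(
    "gloo", init_method=f"file://{store_file}", rank=0, world_size=1
)

model = DDP(torch.nn.Linear(4, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# With real uneven inputs, each rank would have a different number of
# batches; join() lets ranks that run out of data shadow the allreduces
# of the ranks still training, instead of hanging in a collective.
batches = [torch.randn(8, 4) for _ in range(3)]
with model.join():
    for batch in batches:
        optimizer.zero_grad()
        model(batch).sum().backward()
        optimizer.step()

dist.destroy_process_group()
```

The caveat quoted above is exactly about this pattern: `join()` only shadows the collectives DDP itself issues in `forward()`/`backward()`, so any extra collective inside the wrapped forward (which, in Lightning, is the whole `training_step`) is invisible to it.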
agree, I also don't see how this can be supported at the moment.
Also blocked by the duplicate validation data, which causes the validation result to be incorrect. Is there any workaround for this issue? Like only running validation on one gpu?
@JiesiZhao077 Running on one GPU would be the easiest and safest option. You can also set the `UnrepeatedDistributedSampler`, which would also solve the issue, but it introduces the risk of deadlocks. We use it automatically with `trainer.predict`. You could try and use it with `trainer.validate` as well.
Closing this issue as there's no clear way forward at the moment.
@ananthsub commented 6 hours ago
However, the manner in which join tracks collectives can quickly run into issues with other collectives that run in the forward pass / training_step.
In PyTorch, the "with Join" construct is used as a simple wrapper around training steps. It should work in simple cases even if there are more complex cases where it doesn't work.
So, why not simply add an option to the trainer that enables wrapping the invocations of `training_step` with `with Join`? That should be pretty straightforward, and it would leave it up to users to determine when `with Join` is the right thing to use and when it doesn't work.
> So, why not simply add an option to the trainer that enables wrapping the invocations of `training_step` with `with Join`?
The join here is specific to PyTorch DDP. If it were implemented, it would have to live inside the DDP plugin/strategy. For simple cases it may work, but no collective calls are allowed except the ones under `DDP.forward()`/`DDP.backward()`, if I understand correctly.

If we did want to do it "correctly", we would probably have to set `throw_on_early_termination=True` and then handle the error in all custom collective calls, including the ones in torchmetrics. I don't know whether that would work, but it's probably not feasible.
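A sketch of what the `throw_on_early_termination=True` route would look like with the generic `Join` context manager, per the PyTorch docs. Again single-process gloo on CPU for illustration, so no exception actually fires here; in a real multi-rank run with uneven data, every rank would raise `RuntimeError` once the shortest rank runs out.

```python
# Sketch of the generic Join hook with throw_on_early_termination.
# Assumptions: gloo backend, single process, CPU only, for illustration.
import os
import tempfile

import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

store_file = os.path.join(tempfile.mkdtemp(), "store")
dist.init_process_group(
    "gloo", init_method=f"file://{store_file}", rank=0, world_size=1
)

model = DDP(torch.nn.Linear(4, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
steps_done = 0

# With throw_on_early_termination=True, the first rank to exhaust its
# data makes *all* ranks raise RuntimeError, so user-defined collectives
# (torchmetrics syncs, self.log(..., sync_dist=True), ...) would need
# matching error handling to bail out together instead of deadlocking.
try:
    with Join([model], throw_on_early_termination=True):
        for batch in [torch.randn(8, 4) for _ in range(3)]:
            optimizer.zero_grad()
            model(batch).sum().backward()
            optimizer.step()
            steps_done += 1
except RuntimeError:
    pass  # in a real uneven multi-rank run, every rank lands here

dist.destroy_process_group()
```

The hard part flagged above is that every library issuing its own collectives (torchmetrics included) would have to cooperate with this `except RuntimeError` handling.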
To recap, the plan would be:

- `Trainer(strategy=DDPStrategy(uneven_input_support: bool))`. We could also add a registry string for it.
- When enabled, wrap `training_step` with `Join`.
- Is there a benefit to doing it for `validation_step` and `test_step`? Probably not.
- Should `validation_step` and `test_step` use the `UnrepeatedDistributedSampler` just as `trainer.predict`? Probably yes.
- When the feature is enabled, we don't automatically use the generic `DistributedSampler`, as we wouldn't want to duplicate data to make inputs even. Do we need the `UnrepeatedDistributedSampler`?

Related issues: https://github.com/pytorch/pytorch/issues/49180

Some sources:
- https://pytorch.org/docs/stable/distributed.algorithms.join.html#torch.distributed.algorithms.Join
- https://pytorch.org/tutorials/advanced/generic_join.html
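To make the sampler point concrete, here is a small sketch contrasting the stock `DistributedSampler` (which pads by repeating samples so every rank gets the same count) with an "unrepeated" split in the spirit of Lightning's `UnrepeatedDistributedSampler`. The unrepeated split below is a hypothetical one-liner for illustration, not Lightning's actual implementation.

```python
# 10 samples across 4 replicas: the stock sampler duplicates 2 samples,
# an unrepeated split leaves the per-rank counts uneven instead.
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))
num_replicas = 4

# Stock sampler: ceil(10 / 4) = 3 samples per rank -> 12 in total,
# so 2 samples are seen twice, which skews validation metrics.
padded = [
    list(DistributedSampler(dataset, num_replicas=num_replicas, rank=r, shuffle=False))
    for r in range(num_replicas)
]

# Unrepeated split (hypothetical sketch): rank r takes indices
# r, r + num_replicas, ...; lengths differ but nothing is duplicated.
unrepeated = [
    list(range(len(dataset)))[r::num_replicas] for r in range(num_replicas)
]

print(sum(len(p) for p in padded))        # 12 (padded)
print([len(u) for u in unrepeated])       # [3, 3, 2, 2] (uneven, no repeats)
```

The uneven per-rank lengths are exactly why such a sampler would need `Join` (or single-device execution) to avoid hanging on collectives.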
> Is there a benefit to doing it for `validation_step` and `test_step`? Probably not

I assume there is, if collectives are being used. For example, `sync_dist=True` in `self.log` or similar. However, we don't wrap the model in DDP during val and test, so join won't be available anyway.
> When the feature is enabled, we don't automatically use the generic DistributedSampler as we wouldn't want to duplicate data to make inputs even.
https://github.com/pytorch/pytorch/issues/49180 is great! Hopefully this will clarify the drop_last argument which has a slightly misleading/incomplete description :) We would indeed need the UnrepeatedDistributedSampler.
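On the `drop_last` point: a quick sketch of what the two settings actually do per rank (this is the behavior the linked pytorch/pytorch#49180 discussion tries to clarify). With 10 samples on 4 replicas, `drop_last=False` pads up to 12 by repeating samples, while `drop_last=True` drops the tail down to 8; neither gives uneven-but-complete shards.

```python
# Per-rank lengths of DistributedSampler for 10 samples on 4 replicas.
from torch.utils.data.distributed import DistributedSampler

data = list(range(10))
pad = DistributedSampler(data, num_replicas=4, rank=0, shuffle=False, drop_last=False)
drop = DistributedSampler(data, num_replicas=4, rank=0, shuffle=False, drop_last=True)

print(len(pad))   # 3  -> 4 ranks * 3 = 12 total, 2 duplicates added
print(len(drop))  # 2  -> 4 ranks * 2 = 8 total, 2 samples dropped
```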
Hi everyone, I'm gathering information on what is needed in order to support this properly:

1. `torch.distributed.algorithms.Join` (https://pytorch.org/docs/stable/distributed.algorithms.join.html) as a context manager in which the model is run.
2. `UnrepeatedDistributedSamplerWrapper`
3. Find collectives in the model's forward pass (such as `SyncBatchNorm`) and have them as arguments to the `Join` context manager from 1.
4. `self.log(..., sync_dist=True)`

Is that it? cc @awaelchli, @carmocca.
@otaj Almost.
Additionally, all metrics from `torchmetrics` would have to be considered, as they are also capable of issuing syncs on their own. And in general, the user can run arbitrary syncing calls within each of the steps, which also have to be considered (which will be the trickiest part, I guess).
Oh, those `torchmetrics` are going to be fun... :sweat_smile: I think capturing user calls can be solved with yet another context manager (our own, custom one), what do you think?
If we can capture user calls with that, it might work similarly with `torchmetrics`. So let's ignore those metrics for now, and if you get a working solution for everything else, I'm sure we'll manage to integrate metrics with that :D
Let's check the option with LightningLite first 🦦
Here is the corresponding issue as suggested in planning: #14635
Hi, any updates on this issue?
See more details: https://github.com/pytorch/pytorch/issues/38174
cc @borda @tchaton @rohitgr7 @akihironitta @awaelchli