pytorch / PiPPy

Pipeline Parallelism for PyTorch
BSD 3-Clause "New" or "Revised" License

[Bug?] Gradient Synchronization for DDP #1133

Open jianweif opened 3 months ago

jianweif commented 3 months ago

According to the `no_sync` function description in https://github.com/pytorch/pytorch/blob/main/torch/nn/parallel/distributed.py#L1424:

.. warning::
    The forward pass should be included inside the context manager, or
    else gradients will still be synchronized.

The current code wraps the forward and backward passes in separate `no_sync` contexts, so per this warning gradient synchronization will still be triggered.
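
For reference, the pattern the warning describes looks roughly like the following (a minimal, self-contained sketch using a single-process gloo group and a toy Linear module, not PiPPy code):

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def main():
        # Single-process gloo group just so a DDP wrapper can be constructed.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)

        model = DDP(torch.nn.Linear(4, 4))
        x = torch.randn(2, 4)

        # Documented pattern: forward AND backward both run inside the same
        # no_sync() context, so the gradient all-reduce is skipped for this step.
        with model.no_sync():
            model(x).sum().backward()

        # A later step outside no_sync() reduces the accumulated gradients.
        model(x).sum().backward()

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()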

jianweif commented 3 months ago

Code reference for the forward pass under `no_sync` with DDP: https://github.com/pytorch/PiPPy/blob/main/pippy/_PipelineStage.py#L425
And for the backward pass under `no_sync` with DDP: https://github.com/pytorch/PiPPy/blob/main/pippy/_PipelineStage.py#L444

Since these are separate calls, `no_sync` will not take effect and gradient synchronization will still be triggered.
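
Paraphrasing the shape of that code (hypothetical helper functions for illustration, not the actual `_PipelineStage.py` implementation):

    # Hypothetical sketch of the shape being discussed -- NOT the actual
    # _PipelineStage.py code. The forward and backward passes each enter
    # their own, separate no_sync() context.
    def run_forward(ddp_module, x):
        with ddp_module.no_sync():      # context manager #1
            return ddp_module(x)

    def run_backward(ddp_module, loss):
        with ddp_module.no_sync():      # context manager #2, entered later
            loss.backward()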

kwen2501 commented 3 months ago

Hi, thanks for reporting this.

You are right that the forward and backward passes are called in separate context managers.

It seems to me that both context managers would control the same flag: self.require_backward_grad_sync, where self here refers to the DDP module. (I found this from the code of the no_sync manager.)
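
That is, paraphrasing the upstream `no_sync` code (simplified, not verbatim):

    from contextlib import contextmanager

    # Simplified paraphrase of DistributedDataParallel.no_sync -- not
    # verbatim. Entering the context saves the current value of
    # require_backward_grad_sync and sets it to False; exiting restores it.
    @contextmanager
    def no_sync(self):
        old_require_backward_grad_sync = self.require_backward_grad_sync
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            self.require_backward_grad_sync = old_require_backward_grad_sync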

Thus, I wonder if calling the forward and backward passes in separate managers might be okay? Please correct me if I missed something. Thanks!

jianweif commented 3 months ago

Thanks for your quick response! I don't have an exact root cause, but other users have also reported that calling forward and backward separately under the `no_sync` context still triggers gradient synchronization: https://discuss.pytorch.org/t/whats-no-sync-exactly-do-in-ddp/170259. I am not sure whether this is still an issue today, so I would like to confirm here.
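
One way to confirm would be a small two-rank check (a hedged sketch, not PiPPy code; assumes launching with `torchrun --nproc_per_node=2` and the gloo backend): give each rank different data, split forward and backward into separate `no_sync()` contexts, and compare gradients across ranks afterwards. Matching gradients would mean the all-reduce still ran.

    # Hypothetical repro sketch, not part of PiPPy. Launch with e.g.:
    #   torchrun --nproc_per_node=2 check_no_sync.py
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def main():
        dist.init_process_group("gloo")
        rank = dist.get_rank()

        torch.manual_seed(0)                  # identical init on every rank
        model = DDP(torch.nn.Linear(4, 1))

        torch.manual_seed(rank)               # different data per rank
        x = torch.randn(8, 4)

        with model.no_sync():                 # context #1: forward only
            out = model(x)
        with model.no_sync():                 # context #2: backward only
            out.sum().backward()

        grad = model.module.weight.grad.clone()
        gathered = [torch.zeros_like(grad) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, grad)
        if rank == 0:
            # If gradients match across ranks despite no_sync(), the
            # all-reduce still ran, i.e. the split-context pattern did
            # not skip gradient synchronization.
            print("grads equal across ranks:",
                  all(torch.allclose(gathered[0], g) for g in gathered[1:]))

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()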