function2-llx opened this issue 1 year ago
Bug description

Gradients are not synchronized when using manual optimization together with DDPStrategy and static_graph=True.

What version are you seeing the problem on?

v2.0.5
How to reproduce the bug

Create main.py as below and run python main.py fit with two GPUs:

```python
# main.py
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import DDPStrategy


class MyModel(BoringModel):
    def __init__(self):
        super().__init__()
        # switch to manual optimization
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx: int):
        optimizer = self.optimizers()
        optimizer.zero_grad()
        out = super().training_step(batch, batch_idx)
        loss = out['loss']
        self.manual_backward(loss)
        optimizer.step()

    def on_train_batch_end(self, *args, **kwargs):
        # print the gradient and the updated parameter on each rank
        print(self.layer.bias.grad, f'[rank {self.global_rank}] grad')
        print(self.layer.bias, f'[rank {self.global_rank}]')


def main():
    LightningCLI(
        MyModel,
        save_config_kwargs={'overwrite': True},
        trainer_defaults={
            'strategy': DDPStrategy(static_graph=True),
            'max_steps': 1,
            'enable_progress_bar': False,
        },
        seed_everything_default=42,
    )


if __name__ == '__main__':
    main()
```
Error messages and logs

The outputs of the script above are as follows; the gradients are not synchronized:

```
tensor([-1.8170, -1.3621], device='cuda:0') [rank 0] grad
tensor([-1.1449, -2.1265], device='cuda:1') [rank 1] grad
Parameter containing:
tensor([0.2359, 0.0994], device='cuda:0', requires_grad=True) [rank 0]
Parameter containing:
tensor([0.1687, 0.1758], device='cuda:1', requires_grad=True) [rank 1]
```
When setting static_graph=False or using automatic optimization, the outputs are as follows and the gradients are synchronized:

```
tensor([-1.4809, -1.7443], device='cuda:0') [rank 0] grad
tensor([-1.4809, -1.7443], device='cuda:1') [rank 1] grad
Parameter containing:
tensor([0.2023, 0.1376], device='cuda:0', requires_grad=True) [rank 0]
Parameter containing:
tensor([0.2023, 0.1376], device='cuda:1', requires_grad=True) [rank 1]
```
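For what it's worth, the divergence can also be checked programmatically instead of by eye. The sketch below uses a hypothetical helper (not part of the repro above) that all-reduces a copy of a gradient and compares it with the local value; it assumes the default process group has already been initialized by the DDP strategy.

```python
import torch
import torch.distributed as dist


def grad_is_synced(grad: torch.Tensor) -> bool:
    """Return True if this rank's gradient equals the cross-rank average."""
    averaged = grad.detach().clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)  # sum over all ranks
    averaged /= dist.get_world_size()                # turn the sum into a mean
    return torch.allclose(grad, averaged)


# e.g. inside on_train_batch_end:
#   print(f'[rank {self.global_rank}] synced:', grad_is_synced(self.layer.bias.grad))
```

A synchronized gradient is identical on every rank and therefore already equals the cross-rank mean, so the check returns True; with the first output above it would return False on both ranks.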
Environment

```
#- PyTorch Lightning Version (e.g., 1.5.0): 2.0.5
#- PyTorch Version (e.g., 2.0): 2.0.1
#- Python version (e.g., 3.9): 3.11.4
#- OS (e.g., Linux): Linux
#- How you installed Lightning (`conda`, `pip`, source): pip
```

More info

No response
I experience the same with pytorch-lightning==2.3.0. I think it might be caused by this line.
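To help narrow down whether this is a Lightning issue or a PyTorch DDP issue, one option is to repeat the same one-step manual loop with raw DistributedDataParallel and static_graph=True, outside Lightning. The script below is only a sketch under that assumption (hypothetical file name, launched with torchrun on two GPUs); comparing its printed gradients across ranks shows whether plain DDP reduces them in this setting.

```python
# ddp_static_graph_check.py  (hypothetical stand-alone check, no Lightning)
# Launch with: torchrun --nproc_per_node=2 ddp_static_graph_check.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    dist.init_process_group('nccl')
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(42 + rank)  # different data per rank, like a distributed sampler

    model = torch.nn.Linear(32, 2).cuda()
    ddp_model = DDP(model, device_ids=[rank], static_graph=True)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)

    # one manual optimization step, mirroring the Lightning training_step above
    x = torch.randn(64, 32, device='cuda')
    loss = ddp_model(x).sum()
    optimizer.zero_grad()
    loss.backward()   # DDP is expected to all-reduce gradients during backward
    optimizer.step()

    print(f'[rank {rank}] bias grad: {model.bias.grad}')
    dist.destroy_process_group()


if __name__ == '__main__':
    main()
```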