Lightning-AI / pytorch-lightning


Gradient checkpointing and ddp do not work together #20395

Open rubenweitzman opened 3 weeks ago

rubenweitzman commented 3 weeks ago

Bug description

I am launching a script that trains a model with Fabric. Training works when I use gradient checkpointing without DDP, and also when I use DDP without gradient checkpointing. However, when I enable both DDP and gradient checkpointing (activated through Hugging Face's gradient_checkpointing_enable() function), I get this error:

[rank0]:   File "/home/.../v2/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]: RuntimeError: expect_autograd_hooks_ INTERNAL ASSERT FAILED at "../torch/csrc/distributed/c10d/reducer.cpp":1591, please report a bug to PyTorch. 

The scripts were launched with:

fabric = Fabric(
    accelerator="gpu",
    loggers=loggers,
    precision=opt.precision,
    strategy=DDPStrategy(process_group_backend="nccl", find_unused_parameters=False, static_graph=True),
)

When I launch with strategy=DDPStrategy(process_group_backend="nccl", find_unused_parameters=True, static_graph=False) instead, I get a different error:

[rank0]: Parameter at index 560 with name reader.decoder.transformer.h.11.mlp.c_proj.bias has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.

Thanks in advance for your help.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

lantiga commented 1 week ago

Thank you @rubenweitzman, would you be able to provide a minimal reproduction? It will speed up finding a fix. Thanks in advance!
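
As a possible starting point for such a reproduction, here is a small single-process sketch of the checkpointing mechanism itself. It is not the reporter's code and has not been confirmed to trigger or fix the errors above. It only shows `torch.utils.checkpoint.checkpoint` with `use_reentrant=False` on a hypothetical `TinyBlock` module and checks that the gradients match a run without checkpointing. Whether the non-reentrant mode avoids the DDP errors in this issue is an assumption that would need to be tested under DDP.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    """Hypothetical stand-in for one MLP block of the real model."""

    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def grads(use_checkpoint: bool) -> list[torch.Tensor]:
    """Run one forward/backward pass and return a copy of each parameter gradient."""
    torch.manual_seed(0)  # same weights and inputs in both runs
    block = TinyBlock()
    x = torch.randn(4, 8)
    if use_checkpoint:
        # Non-reentrant checkpointing: activations are recomputed during backward.
        out = checkpoint(block, x, use_reentrant=False)
    else:
        out = block(x)
    out.sum().backward()
    return [p.grad.clone() for p in block.parameters()]


plain = grads(use_checkpoint=False)
ckpt = grads(use_checkpoint=True)

# Checkpointing should change memory use, not the resulting gradients.
assert all(torch.allclose(a, b) for a, b in zip(plain, ckpt))
```

For a Hugging Face model, recent `transformers` releases accept `model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})` to select the same mode. Whether that argument is available depends on the installed version, so treat it as something to verify rather than a confirmed fix.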