Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.

[FSDP] Support gradient clipping by norm #309

Open carmocca opened 6 months ago

carmocca commented 6 months ago

🚀 Feature

Pitch

Port PyTorch FSDP's clip_grad_norm_ implementation (https://github.com/pytorch/pytorch/blob/c4a157086482899f0640d03292e5d2c9a6a3db68/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1069-L1194) to work with Thunder's FSDP.

This could be importable as from thunder.distributed.utils import clip_grad_norm_. Alternatively, we could move FSDP into thunder.distributed.fsdp and put this alongside it (from thunder.distributed.fsdp import clip_grad_norm_). Bikeshedding welcome.

PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html, https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
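A minimal sketch of what such a helper could look like, loosely following torch.nn.utils.clip_grad_norm_. It assumes the proposed (not yet existing) thunder.distributed.utils location and that each rank only holds its local shard of every gradient, so the locally computed norms must be reduced across the process group before the clip coefficient is applied:

```python
import torch
import torch.distributed as dist


def clip_grad_norm_(parameters, max_norm, norm_type=2.0, process_group=None):
    """Sketch: clip gradients of FSDP-sharded parameters by their global norm.

    Each rank only sees a shard of every gradient, so the norm computed locally
    is combined across the process group before clipping. Names and module
    location are assumptions, not an existing API.
    """
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)

    if norm_type == float("inf"):
        # inf-norm: take the elementwise maximum locally, then reduce with MAX
        total_norm = torch.stack([g.detach().abs().max() for g in grads]).max()
        dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=process_group)
    else:
        # p-norm: sum |g|^p over the local shards, reduce with SUM, take the p-th root
        local_sum = torch.stack(
            [torch.linalg.vector_norm(g.detach(), norm_type) ** norm_type for g in grads]
        ).sum()
        dist.all_reduce(local_sum, op=dist.ReduceOp.SUM, group=process_group)
        total_norm = local_sum ** (1.0 / norm_type)

    # Scale every local gradient shard in place when the global norm exceeds max_norm
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for g in grads:
            g.detach().mul_(clip_coef.to(g.dtype))
    return total_norm
```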

cc @carmocca @awaelchli @crcrpar

carmocca commented 6 months ago

When we support and test compiling fwd-bwd-step together, we might want to reimplement this as a transform. But for the current pattern, where gradient clipping happens outside of the trace, we can simply write an ad-hoc function that the user calls.
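For reference, a rough illustration of that ad-hoc call pattern (the import path is the one proposed above and does not exist yet; model, optimizer, and dataloader are placeholders):

```python
# Hypothetical usage: clipping runs eagerly between backward and optimizer.step(),
# outside of the compiled trace.
from thunder.distributed.utils import clip_grad_norm_  # proposed helper, not an existing API

for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
```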