Open carmocca opened 6 months ago
When we support and test compiling fwd-bwd-step together, we might want to reimplement this as a transform. But for the current pattern, where gradient clipping happens outside of the trace, we can simply write an ad-hoc function that the user calls.
🚀 Feature
Pitch
Port https://github.com/pytorch/pytorch/blob/c4a157086482899f0640d03292e5d2c9a6a3db68/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1069-L1194 to work with Thunder's FSDP.
This could be importable through `from thunder.distributed.utils import clip_grad_norm_`. We could also move FSDP into `thunder.distributed.fsdp` and put this alongside it (`from thunder.distributed.fsdp import clip_grad_norm_`). Bikeshedding welcome.

PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html, https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
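For context, the core of the linked PyTorch implementation is: each rank computes the norm of its local gradient shard, the per-rank norms are combined with an all-reduce to get the global norm, and every rank then scales its shard by the same clip coefficient. A minimal sketch of that math, with the all-reduce simulated by a plain sum (no `torch.distributed` here; the `clip_grad_norm_` below is a hypothetical standalone function, not Thunder's API):

```python
import math

def clip_grad_norm_(shards, max_norm):
    """Sketch of sharded grad-norm clipping (2-norm only).

    `shards` plays the role of the per-rank gradient shards; in a real
    FSDP port, step 2 would be an all_reduce(SUM) across ranks.
    """
    # 1. Each rank computes the squared 2-norm of its local shard.
    local_sq = [sum(g * g for g in shard) for shard in shards]
    # 2. Combine across "ranks" (simulated all-reduce) and take the root.
    total_norm = math.sqrt(sum(local_sq))
    # 3. Every rank applies the same clip coefficient, so the sharded
    #    gradients stay consistent with unsharded clipping.
    clip_coef = min(max_norm / (total_norm + 1e-6), 1.0)
    for shard in shards:
        for i in range(len(shard)):
            shard[i] *= clip_coef
    return total_norm

# Two "ranks"; the global gradient is (3, 0, 4), so the global norm is 5.
shards = [[3.0, 0.0], [4.0]]
total = clip_grad_norm_(shards, max_norm=1.0)
```

The key property to preserve in the port is step 2: the local norms alone are meaningless, and the clip coefficient must be computed from the global norm so all ranks scale identically.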
cc @carmocca @awaelchli @crcrpar