ananthsub closed this issue 1 year ago.
Let me take this :)
Any updates on this issue?
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, PyTorch Lightning Team!
Proposed refactoring or deprecation
Motivation
Lightning defines its own utility for an all-gather that propagates gradients: https://github.com/PyTorchLightning/pytorch-lightning/blob/d515bcac969c2a485ada673e302bfac51f142331/pytorch_lightning/utilities/distributed.py#L200-L222
However, PyTorch already provides this in torch.distributed.nn.functional: https://github.com/pytorch/pytorch/blob/6b44e75f6bccca7acc8ec31a635f1175c265ac54/torch/distributed/nn/functional.py#L82-L94
So there is no need for Lightning to redefine it.
Pitch
Remove the custom all-gather-with-gradients implementation and call the functional API in torch.distributed.nn.functional instead.
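A minimal sketch of what the replacement call site could look like, assuming callers want a single tensor stacked along a new leading "rank" dimension. The helper name all_gather_with_grad is hypothetical. torch.distributed.nn.functional.all_gather returns a tuple with one tensor per rank, so the sketch stacks that tuple.

```python
import torch
import torch.distributed as dist
from torch.distributed.nn.functional import all_gather


def all_gather_with_grad(tensor, group=None):
    """Gather `tensor` from every rank while keeping it in the autograd graph.

    Hypothetical helper for illustration only. Returns a tensor of shape
    (world_size, *tensor.shape).
    """
    if group is None:
        group = dist.group.WORLD
    # torch's differentiable all_gather returns a tuple with one tensor per rank
    gathered = all_gather(tensor, group=group)
    # Stack into a single tensor along a new leading "rank" dimension
    return torch.stack(gathered, dim=0)
```

Before dropping the custom version, it is worth checking that the backward pass works on every backend Lightning supports (for example Gloo) across the PyTorch versions it targets, since the collective torch uses in backward may differ between versions.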
cc @borda @awaelchli @rohitgr7 @akihironitta @justusschock @tchaton