carmocca opened this issue 1 year ago (status: Open)
@carmocca is there a way to use TransformerEnginePrecision with FSDP?
@carmocca to understand it better, where is the bottleneck? Does Transformer Engine have to support it, or does FSDP?
It has to be supported in TransformerEngine first, and then we would need to integrate whatever is changed into Lightning.
cc @sbhavani in case you know about the progress for this
Transformer Engine + FSDP functionally works but doesn't provide memory savings. We are working on FP8 support in PyTorch's FSDP implementation upstream (i.e. making it understand FP8 tensors), which would provide the memory savings.
@sbhavani is there a timeline that you’re targeting for the feature to be available in FSDP?
@carmocca any plan to support FP8 training with the DeepSpeed strategy?
@sbhavani I see https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/fsdp exists now. Is your last comment still valid?
@carmocca Yes, upstream support for FP8 storage and communication in PyTorch's FSDP implementation is still in progress. However, with the lazy weight initialization in that example, you can use FP8 with FSDP.
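For reference, a minimal sketch of what that combination looks like with plain PyTorch FSDP and Transformer Engine modules (the module sizes, recipe settings, and the torchrun launch are illustrative assumptions; as noted above, this runs functionally but does not yet give FP8 memory savings):

```python
# Sketch only: FP8 compute via Transformer Engine modules wrapped in PyTorch FSDP.
# Assumes a multi-GPU launch, e.g. `torchrun --nproc_per_node=2 this_script.py`.
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Plain initialization shown here for brevity. The NVIDIA example additionally
# defers ("lazy") weight initialization so that full-precision weights are not
# materialized on every rank before FSDP shards them.
model = torch.nn.Sequential(
    te.Linear(1024, 4096),
    te.Linear(4096, 1024),
).cuda()
model = FSDP(model, device_id=local_rank)

recipe = DelayedScaling(fp8_format=Format.HYBRID)  # illustrative recipe settings
x = torch.randn(8, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = model(x)
out.sum().backward()

dist.destroy_process_group()
```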
I would like to suggest another use case where two different precision plugins may be needed. Suppose my lightning module is composed of a transformer and a CNN. Is it possible to use bf16 for the transformer and 16-mixed for the CNN?
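For what it's worth, here is a rough sketch of what that request amounts to in plain PyTorch, outside of Lightning's precision plugins (module shapes are made up, and a GradScaler would normally accompany the fp16 path):

```python
# Manual mix: a transformer in "bf16-true" style (bf16 weights and compute) and a
# CNN in "16-mixed" style (fp32 weights, fp16 compute via autocast). Assumes CUDA.
import torch
import torch.nn as nn

transformer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True).to("cuda", torch.bfloat16)
cnn = nn.Conv2d(3, 8, kernel_size=3).to("cuda")  # weights stay fp32

tokens = torch.randn(4, 16, 64, device="cuda", dtype=torch.bfloat16)
images = torch.randn(4, 3, 32, 32, device="cuda")

seq_out = transformer(tokens)  # bf16 end to end
with torch.autocast(device_type="cuda", dtype=torch.float16):
    img_out = cnn(images)      # fp16 compute on fp32 weights

loss = seq_out.float().mean() + img_out.float().mean()
loss.backward()
```

Whether that behaviour belongs in one precision plugin or in two composed ones is essentially the same design question raised in this issue for the FSDP case.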
@carmocca it might be worth adding FSDP+FP8 support via TE. HF Accelerate just added FSDP+FP8 support as a reference: https://github.com/huggingface/accelerate/blob/main/benchmarks/fp8/fsdp.py
@sbhavani if I understand correctly, as of today FP8 + FSDP support is not available in PyTorch Lightning?
Correct. It's very much on our radar though. Anyone want to run ahead and submit a PR?
I do think that right now the first option ("Create a TransformerEngineFSDPPrecision. This is simple and effective but it creates a maintainability problem.") is going to have the highest ROI. Once we have this fully working, we can think about generalizing in the future.
Here's the starting point if anyone would like to give it a shot: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/plugins/precision/transformer_engine.py
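To make the shape of that option concrete, here is a rough sketch of what such a plugin could look like, starting from that file. The `weights_dtype` attribute and the `mixed_precision_config` hook are assumptions based on the existing TransformerEnginePrecision and FSDPPrecision plugins; a real PR would need to check exactly which hooks and isinstance checks FSDPStrategy relies on.

```python
# Hypothetical TransformerEngineFSDPPrecision sketch; not an existing Lightning class.
import torch
from torch.distributed.fsdp import MixedPrecision

from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision


class TransformerEngineFSDPPrecision(TransformerEnginePrecision):
    """Transformer Engine FP8 compute plus the FSDP mixed-precision config that
    FSDPStrategy expects from its precision plugin."""

    @property
    def mixed_precision_config(self) -> MixedPrecision:
        # Keep parameters, gradients and buffers in the TE weights dtype (e.g. bf16);
        # FP8 is applied by TE inside the forward/backward of its own modules.
        dtype = self.weights_dtype  # attribute name assumed from the existing plugin
        return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
```

FSDPStrategy may also validate isinstance(precision, FSDPPrecision), in which case that check would need to be relaxed or the class would need to inherit from both.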
Description & Motivation
Both the Fabric and Trainer strategies are designed to have a single plugin enabled from the beginning to the end of the program. This has been fine historically; however, some strategies require tailored plugin implementations that are functionally equal to other plugins.

For instance, single-device training with bf16-true precision will use the HalfPrecision plugin, but FSDP training with bf16-true precision will use the FSDPPrecision plugin. Recent cutting-edge plugins such as TransformerEnginePrecision and BitsandbytesPrecision also implement the basic bf16-true functionality (example). So there's a lot of overlap.

The challenge then becomes: how do I enable TransformerEnginePrecision to work with FSDPStrategy if FSDPStrategy is designed to work with FSDPPrecision only?

Note that I'm using these specific classes to prove the point, but the design issue applies to any strategy that requires a specific plugin class. DeepSpeedStrategy and XLAStrategy would also be examples of this.
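As a concrete illustration of the current API surface (exact constructor argument names such as `weights_dtype` may differ between Lightning versions, and the last combination is precisely the one that does not work today):

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import TransformerEnginePrecision
from lightning.fabric.strategies import FSDPStrategy

# Single-device bf16-true: resolved internally to the HalfPrecision plugin.
fabric_single = Fabric(accelerator="cuda", devices=1, precision="bf16-true")

# FSDP bf16-true: resolved internally to the FSDPPrecision plugin.
fabric_fsdp = Fabric(accelerator="cuda", devices=2, strategy=FSDPStrategy(), precision="bf16-true")

# The combination this issue is about: a tailored precision plugin together with FSDP.
fabric_te_fsdp = Fabric(
    accelerator="cuda",
    devices=2,
    strategy=FSDPStrategy(),
    plugins=TransformerEnginePrecision(weights_dtype=torch.bfloat16),
)
```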
Pitch
Continuing the example, there are three ways this could be solved:

1. Create a TransformerEngineFSDPPrecision. This is simple and effective, but it creates a maintainability problem.
2. Allow combining precision plugins, e.g. plugins=[TransformerEnginePrecision(), FSDPPrecision()]. But there will likely be dependencies between them.
3. Rethink the separation of responsibilities, similar to the CheckpointIO plugins.
Alternatives
No response
Additional context
No response
cc @borda @tchaton @justusschock @awaelchli @carmocca