Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Support combinations of precision plugins #18679

Open carmocca opened 11 months ago

carmocca commented 11 months ago

Description & Motivation

Both the Fabric and Trainer strategies are designed to have a single plugin enabled from the beginning to the end of the program.

This has been fine historically; however, some strategies require tailored plugin implementations that are functionally equivalent to other plugins.

For instance, single-device training with bf16-true precision uses the HalfPrecision plugin, but FSDP training with bf16-true precision uses the FSDPPrecision plugin. Recent cutting-edge plugins such as TransformerEnginePrecision and BitsandbytesPrecision also implement the basic bf16-true functionality (example). So there's a lot of overlap.

The challenge then becomes: how do I enable TransformerEnginePrecision to work with FSDPStrategy if FSDPStrategy is designed to work with FSDPPrecision only?
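To make the constraint concrete, here is a minimal sketch using hypothetical stand-in classes (not Lightning's real ones). It assumes the strategy pins itself to one plugin class with an `isinstance` check, which is why a functionally compatible plugin such as a Transformer Engine one is rejected outright.

```python
# Hypothetical stand-ins, NOT Lightning's real classes. They only mimic a
# strategy that accepts exactly one precision plugin class.


class Precision:
    """Stand-in base class for a precision plugin."""


class FSDPPrecision(Precision):
    """Stand-in for the FSDP-specific precision plugin."""


class TransformerEnginePrecision(Precision):
    """Stand-in for the Transformer Engine precision plugin."""


class FSDPStrategy:
    """Stand-in strategy; the isinstance check below is an assumption."""

    def __init__(self, precision: Precision) -> None:
        # Reject every plugin that is not an FSDPPrecision instance,
        # even if it would be functionally compatible.
        if not isinstance(precision, FSDPPrecision):
            raise TypeError(
                f"FSDPStrategy only works with FSDPPrecision, got {type(precision).__name__}"
            )
        self.precision = precision


FSDPStrategy(FSDPPrecision())  # accepted
try:
    FSDPStrategy(TransformerEnginePrecision())
except TypeError as err:
    print(err)  # FSDPStrategy only works with FSDPPrecision, got TransformerEnginePrecision
```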

Note that I'm using these specific classes to prove the point, but the design issue applies to any strategy that requires a specific plugin class. DeepSpeedStrategy and XLAStrategy would also be examples of this.

Pitch

Continuing the example, there are three ways this could be solved:

  1. The naive way: Create a TransformerEngineFSDPPrecision. This is simple and effective but it creates a maintainability problem.
  2. The independent way: If there are no dependencies between the plugins, we could support plugins=[TransformerEnginePrecision(), FSDPPrecision()]. But there will likely be dependencies.
  3. The smart way: Create an abstraction that is able to compose two (or more?) plugins together and is itself a plugin. There's some precedent for this with the CheckpointIO plugins.
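Option 3 could look roughly like the sketch below. It is not Lightning's API: `Precision`, `CompositePrecision` and `LoggingPrecision` are hypothetical stand-ins, and the hooks shown (`convert_module`, `convert_input`, `forward_context`) are a reduced, assumed subset of what a real precision plugin exposes. The composite applies its children in order and is itself a plugin, so a strategy can hold it like any other.

```python
from contextlib import ExitStack, contextmanager
from typing import Any, Iterator, List


class Precision:
    """Hypothetical minimal plugin interface (stand-in, not Lightning's class)."""

    def convert_module(self, module: Any) -> Any:
        return module

    def convert_input(self, data: Any) -> Any:
        return data

    @contextmanager
    def forward_context(self) -> Iterator[None]:
        yield


class CompositePrecision(Precision):
    """Hypothetical plugin that composes other plugins and is itself a plugin."""

    def __init__(self, plugins: List[Precision]) -> None:
        self.plugins = list(plugins)

    def convert_module(self, module: Any) -> Any:
        # Later plugins see the output of earlier ones; this ordering is
        # where dependencies between plugins would surface.
        for plugin in self.plugins:
            module = plugin.convert_module(module)
        return module

    def convert_input(self, data: Any) -> Any:
        for plugin in self.plugins:
            data = plugin.convert_input(data)
        return data

    @contextmanager
    def forward_context(self) -> Iterator[None]:
        # Enter every child's context; ExitStack exits them in reverse order.
        with ExitStack() as stack:
            for plugin in self.plugins:
                stack.enter_context(plugin.forward_context())
            yield


class LoggingPrecision(Precision):
    """Toy plugin that records the order in which hooks run."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name, self.log = name, log

    def convert_module(self, module: Any) -> Any:
        self.log.append(f"{self.name}.convert_module")
        return module

    @contextmanager
    def forward_context(self) -> Iterator[None]:
        self.log.append(f"{self.name}.enter")
        yield
        self.log.append(f"{self.name}.exit")


log: List[str] = []
combo = CompositePrecision([LoggingPrecision("te", log), LoggingPrecision("fsdp", log)])
combo.convert_module(object())
with combo.forward_context():
    log.append("forward")
print(log)
# ['te.convert_module', 'fsdp.convert_module', 'te.enter', 'fsdp.enter', 'forward', 'fsdp.exit', 'te.exit']
```

One design consequence: a strategy that checks `isinstance(precision, FSDPPrecision)` would reject the composite itself, so under this approach such a check would presumably need to inspect the composite's children instead.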

Alternatives

No response

Additional context

No response

cc @borda @tchaton @justusschock @awaelchli @carmocca

naveenkumarmarri commented 11 months ago

@carmocca is there a way to use TransformerEnginePrecision with FSDP?

carmocca commented 11 months ago

Not for now: https://github.com/NVIDIA/TransformerEngine/issues/401

naveenkumarmarri commented 11 months ago

@carmocca to understand it better, where is the bottleneck? Does Transformer Engine have to support it, or does FSDP?

carmocca commented 11 months ago

TransformerEngine, and then we would need to integrate whatever is changed into Lightning.

cc @sbhavani in case you know about the progress for this

sbhavani commented 11 months ago

Transformer Engine + FSDP functionally works but doesn't provide memory savings. We are working upstream on FP8 support for PyTorch's FSDP implementation (i.e., making FSDP understand FP8 tensors), which would provide memory savings.

naveenkumarmarri commented 11 months ago

@sbhavani is there a timeline that you’re targeting for the feature to be available in FSDP?


naveenkumarmarri commented 10 months ago

@carmocca any plan to support FP8 training for deepspeed strategy?

carmocca commented 6 months ago

@sbhavani I see https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/fsdp exists now. Is your last comment still valid?

sbhavani commented 6 months ago

@carmocca Yes, upstream support for FP8 storage and communication in PyTorch's FSDP implementation is still in progress. However, with lazy weight initialization, as in that example, you can use FP8 with FSDP.

function2-llx commented 1 month ago

I would like to suggest another use case where two different precision plugins may be needed. Suppose my LightningModule is composed of a transformer and a CNN. Is it possible to use bf16 for the transformer and 16-mixed for the CNN?

sbhavani commented 1 month ago

@carmocca it might be worth adding FSDP+FP8 support via TE. HF Accelerate just added FSDP+FP8 support as a reference: https://github.com/huggingface/accelerate/blob/main/benchmarks/fp8/fsdp.py

psr-ai commented 1 month ago

@sbhavani if I understand correctly, FP8 + FSDP is not supported in PyTorch Lightning as of today?

lantiga commented 1 day ago

Correct. It's very much on our radar, though. Does anyone want to run ahead and submit a PR?

I do think that right now:

Create a TransformerEngineFSDPPrecision. This is simple and effective but it creates a maintainability problem.

is going to have the highest ROI. Once we have this fully working we can think about generalizing in the future.
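For anyone picking this up, here is one way the naive option could be shaped, sketched with hypothetical stand-in classes rather than Lightning's real plugins. It assumes both parents implement `convert_module` with cooperative `super()` calls; under that assumption a combined subclass passes an `isinstance(..., FSDPPrecision)` check and runs both conversions in method-resolution order.

```python
from typing import Any, List


class Precision:
    """Hypothetical base plugin (stand-in, not Lightning's class)."""

    def convert_module(self, module: Any) -> Any:
        return module


class FSDPPrecision(Precision):
    """Stand-in for the FSDP-specific plugin."""

    def convert_module(self, module: Any) -> Any:
        module = super().convert_module(module)  # cooperative call
        module.applied.append("fsdp")
        return module


class TransformerEnginePrecision(Precision):
    """Stand-in for the Transformer Engine plugin."""

    def convert_module(self, module: Any) -> Any:
        module = super().convert_module(module)  # cooperative call
        module.applied.append("te")
        return module


class TransformerEngineFSDPPrecision(FSDPPrecision, TransformerEnginePrecision):
    """Hypothetical combined plugin. MRO: self -> FSDP -> TE -> Precision."""


class ToyModule:
    """Toy module that records which conversions touched it."""

    def __init__(self) -> None:
        self.applied: List[str] = []


plugin = TransformerEngineFSDPPrecision()
module = plugin.convert_module(ToyModule())
print(isinstance(plugin, FSDPPrecision), module.applied)  # True ['te', 'fsdp']
```

The maintainability cost shows up here: every new pairing needs its own class, and the order in which conversions run is dictated implicitly by the MRO rather than stated explicitly.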

Here's the starting point if anyone would like to give it a shot: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/plugins/precision/transformer_engine.py