Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

2x slower training speed with FSDP when switching from lightning 1.9 to 2.0 #18028

Open anthonyhu opened 1 year ago

anthonyhu commented 1 year ago

Bug description

Hello! Thank you for the integration of fsdp in the lightning trainer - it's a game changer.

I tried to switch from lightning==1.9.4 to the newest lightning==2.0.4 but observed a significant training slowdown in your lit-GPT repository.

Previously, with lightning==1.9.4 and torch==2.0.1, running the command python train.py --implementation nanogpt --batch_size 1 --block_size 8192 --strategy fsdp-gpt --model_type gpt2-xl gave a training speed of 4.04s/iteration and a GPU memory consumption of 4.6GB (on an 8x A100 80GB machine).

When updating to lightning==2.0.4 (with the same version of torch), I had to switch from the previous pytorch_lightning.strategies.fully_sharded_native.DDPFullyShardedNativeStrategy to pl.strategies.FSDPStrategy here: https://github.com/Lightning-Universe/lit-GPT/blob/main/lightning_gpt/models.py#L287. Running the same command, I now get 9.52s/it and 6.7GB of memory consumption.

That's a 2x training slowdown and 47% memory increase 😢 Would you have any idea why?
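For reference, the change boils down to the following (a minimal sketch, not the exact lit-GPT configuration; the Trainer arguments are placeholders):

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import FSDPStrategy

# Lightning 1.9.x (old): the native-FSDP strategy lived under a different name
# from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
# strategy = DDPFullyShardedNativeStrategy()

# Lightning 2.0.x (new): consolidated into FSDPStrategy
strategy = FSDPStrategy()  # auto_wrap_policy, activation_checkpointing, etc. can be passed here

trainer = pl.Trainer(
    accelerator="cuda",
    devices=8,
    precision="16-mixed",  # placeholder; 1.9 used precision=16
    strategy=strategy,
)
```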

What version are you seeing the problem on?

v2.0

How to reproduce the bug

python train.py --implementation nanogpt --batch_size 1 --block_size 8192 --strategy fsdp-gpt --model_type gpt2-xl

Error messages and logs

# Error messages and logs here please

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 2.0.4
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0): 2.0.1
#- Python version (e.g., 3.9): 3.9
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version:
#- GPU models and configuration: 8x A100 80GB
#- How you installed Lightning (`conda`, `pip`, source): pip in a fresh conda environment
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

No response

cc @borda @awaelchli @carmocca

carmocca commented 1 year ago

My hunch here is that pre-2.0 it was using 16-bit precision by default while 2.0 is using 32-bit. You should be able to verify this by manually setting the precision with 2.0 installed.
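Something like the following should make the comparison apples-to-apples (a minimal sketch; only the precision argument matters here, the other Trainer arguments are placeholders):

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import FSDPStrategy

# Pin the precision explicitly instead of relying on version-specific defaults.
trainer = pl.Trainer(
    accelerator="cuda",
    devices=8,
    strategy=FSDPStrategy(),
    precision="16-mixed",  # compare against precision="32-true" to measure the difference
)
```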

carmocca commented 1 year ago

Also, I suggest that you check out https://github.com/Lightning-AI/lit-gpt for all your GPT training needs. The one you linked is an earlier version that hasn't been updated.

anthonyhu commented 1 year ago

Thanks for the answer! Both models are trained with 16-bit precision, which is set here: https://github.com/Lightning-Universe/lit-GPT/blob/main/train.py#L90

Could it be related to switching from the fairscale implementation to the native one by default?
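One way to check which implementation is actually wrapping the model (a sketch using a custom callback; the class name is illustrative):

```python
import pytorch_lightning as pl
from torch.distributed.fsdp import FullyShardedDataParallel

class StrategyInspector(pl.Callback):
    """Illustrative callback that logs which sharding strategy/wrapper is active."""

    def on_train_start(self, trainer, pl_module):
        # The strategy class name shows whether FSDPStrategy (2.0) or the
        # older DDPFullyShardedNativeStrategy (1.9) is in use.
        print(f"strategy: {type(trainer.strategy).__name__}")
        # If the model is wrapped by PyTorch's native FSDP, this prints True;
        # a fairscale-based wrapper would not be an instance of this class.
        print(f"native FSDP wrapping: {isinstance(trainer.strategy.model, FullyShardedDataParallel)}")
```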

zaptrem commented 7 months ago

> Thanks for the answer! Both models are trained with 16-bit precision, which is set here: https://github.com/Lightning-Universe/lit-GPT/blob/main/train.py#L90
>
> Could it be related to switching from the fairscale implementation to the native one by default?

Another idea for future observers: this setting may not be reaching FSDP, which has its own MixedPrecision settings object.
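For anyone who wants to try that, here is a sketch of passing FSDP's own policy through the strategy (assuming the Lightning 2.0 FSDPStrategy(mixed_precision=...) argument; not verified as the cause of the slowdown here):

```python
import torch
from pytorch_lightning.strategies import FSDPStrategy
from torch.distributed.fsdp import MixedPrecision

# Tell FSDP itself to keep parameters, gradient reduction, and buffers in fp16,
# independently of the Trainer-level precision flag.
fsdp_precision = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)

strategy = FSDPStrategy(mixed_precision=fsdp_precision)
```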