anthonyhu opened 1 year ago
My hunch here is that before 2.0 it was using 16-bit precision by default and 2.0 is using 32-bit. You should be able to verify this by setting the precision manually with 2.0 installed.
Also, I suggest that you check out https://github.com/Lightning-AI/lit-gpt for all your GPT training needs. The one you linked is an earlier version that hasn't been updated.
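One way to run that check is to pass the precision explicitly instead of relying on a default. Lightning 2.0 accepts explicit string flags such as `"16-mixed"`, `"bf16-mixed"` and `"32-true"`. The sketch below assumes the legacy-to-2.0 correspondence written out in `_LEGACY_TO_2_0`; `to_lightning2_precision` is a hypothetical helper, not part of Lightning.

```python
# Hypothetical helper (not a Lightning API): maps legacy 1.x-style precision
# values to the explicit string flags used by Lightning 2.0, so the intended
# precision is spelled out rather than left to a default.
_LEGACY_TO_2_0 = {
    "64": "64-true",
    "32": "32-true",
    "16": "16-mixed",
    "bf16": "bf16-mixed",
}
_VALID_2_0 = set(_LEGACY_TO_2_0.values())


def to_lightning2_precision(value):
    """Return an explicit Lightning 2.0 precision string for `value`."""
    key = str(value)
    if key in _VALID_2_0:
        return key  # already an explicit 2.0-style flag
    try:
        return _LEGACY_TO_2_0[key]
    except KeyError:
        raise ValueError(f"unrecognized precision value: {value!r}") from None


if __name__ == "__main__":
    for legacy in (16, "bf16", 32):
        print(legacy, "->", to_lightning2_precision(legacy))
```

The result would then be passed as `Trainer(precision=to_lightning2_precision(16))`, i.e. `precision="16-mixed"`, so both Lightning versions can be compared with the same, explicitly stated precision.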
Thanks for the answer! Both models are trained with 16-bit precision, which is set here: https://github.com/Lightning-Universe/lit-GPT/blob/main/train.py#L90
Could it be related to the move from fairscale to the native implementation as the default?
Another idea for future observers: this setting may not be reaching FSDP, which has its own `MixedPrecision` settings object.
Bug description
Hello! Thank you for integrating FSDP into the Lightning Trainer; it's a game changer.
I tried to switch from `lightning==1.9.4` to the newest `lightning==2.0.4`, but observed a significant slowdown in training in your `lit-GPT` repository.

Previously, with `lightning==1.9.4` and `torch==2.0.1`, running the command `python train.py --implementation nanogpt --batch_size 1 --block_size 8192 --strategy fsdp-gpt --model_type gpt2-xl` gave a training speed of 4.04 s/iteration and 4.6 GB of GPU memory consumption (on a machine with 8x A100 80GB GPUs).

When updating to `lightning==2.0.4` (with the same version of torch), I had to switch from the default `DDPFullyShardedNativeStrategy` (imported via `from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy`) to `pl.strategies.FSDPStrategy` here: https://github.com/Lightning-Universe/lit-GPT/blob/main/lightning_gpt/models.py#L287. Running the same command, I now get 9.52 s/it and 6.7 GB of memory consumption.

That's a more than 2x training slowdown and a roughly 46% memory increase 😢 Would you have any idea why?
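A quick sanity check of the reported figures, using only the numbers quoted in this issue (`relative_change` is a throwaway helper name introduced here for illustration):

```python
# Figures reported in this issue (8x A100 80GB, gpt2-xl, block_size 8192).
old = {"sec_per_it": 4.04, "mem_gb": 4.6}  # lightning==1.9.4
new = {"sec_per_it": 9.52, "mem_gb": 6.7}  # lightning==2.0.4


def relative_change(before, after):
    """Fractional change from `before` to `after` (0.5 means +50%)."""
    return (after - before) / before


# Ratio of seconds per iteration: how many times slower the new run is.
slowdown = new["sec_per_it"] / old["sec_per_it"]
# Fractional growth in GPU memory consumption.
mem_increase = relative_change(old["mem_gb"], new["mem_gb"])

print(f"slowdown: {slowdown:.2f}x")
print(f"memory increase: {mem_increase:.0%}")
```

This prints a slowdown of about 2.36x and a memory increase of about 46%.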
What version are you seeing the problem on?
v2.0
How to reproduce the bug
Error messages and logs
Environment
Current environment
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 2.0.4
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0): 2.0.1
#- Python version (e.g., 3.9): 3.9
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version:
#- GPU models and configuration: 8x A100 80GB
#- How you installed Lightning(`conda`, `pip`, source): pip in a fresh conda environment
#- Running environment of LightningApp (e.g. local, cloud):
```
More info
No response
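The CUDA/cuDNN field in the environment template was left empty. A minimal standard-library sketch for gathering the other version fields (`collect_env` is a hypothetical name; it only reads installed package metadata and does not detect CUDA):

```python
import importlib.metadata
import platform


def collect_env(packages=("lightning", "torch")):
    """Gather the version fields requested by the issue template."""
    info = {
        "python": platform.python_version(),
        "os": platform.system(),
    }
    for name in packages:
        try:
            info[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            info[name] = None  # package not installed in this environment
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```

With PyTorch installed, `torch.version.cuda` and `torch.backends.cudnn.version()` report the CUDA and cuDNN versions that torch was built with, which would fill in the missing field.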
cc @borda @awaelchli @carmocca