anandhperumal opened 8 months ago
Hi @anandhperumal
The reference implementation is using LoRA, but I don't see this configured anywhere in your code snippet. This will make a very big difference in memory consumption. Furthermore, you didn't enable activation checkpointing in FSDP in the code above, but you reported doing so in the reference implementation, which will have another big impact. Please check again: it's important to compare equivalent settings.
If you'd like to try out LoRA with Lightning, we have an implementation here (and docs).
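For reference, enabling activation checkpointing with Lightning's FSDP strategy looks roughly like the sketch below. This is a minimal sketch, not your script; MistralDecoderLayer is an assumption about the block class this checkpoint uses, so substitute whatever transformer block class your model actually contains.

```python
# Minimal sketch: activation checkpointing with Lightning's FSDPStrategy.
# MistralDecoderLayer is an assumed block class for this checkpoint; swap in the
# transformer block class your model actually uses.
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

strategy = FSDPStrategy(
    # Recompute activations for these blocks in the backward pass instead of storing them.
    activation_checkpointing_policy={MistralDecoderLayer},
)
trainer = L.Trainer(accelerator="cuda", devices=8, strategy=strategy, precision="bf16-true")
```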
@awaelchli The reference code also has a "full" option, which trains the entire model; I'm using the full option.
python -u src/train.py --world_size 8 --master_port 12356 --model_name openchat/openchat-3.5 --gradient_accumulation_steps 4 --batch_size 1 --precision bf16 --train_type full --use_gradient_checkpointing false --save_model true --use_activation_cpu_offload false --use_cpu_offload false --num_epochs 3 --lr_scheduler linear
Also, I updated the Lightning activation checkpointing policy, yet it made no difference. In fact, I even enabled cpu_offload for PyTorch Lightning and not for plain PyTorch. Here is how I set the policy:
sharding_strategy['activation_checkpointing_policy'] = policy
For your reference, the image below shows the openchat memory consumption for the PyTorch FSDP code, whereas Lightning doesn't even finish one training step. Please let me know if I'm missing anything.
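To put the two runs on equal footing, a small helper like the one below can report rank-local peak memory in either script right after the first optimizer step. This is only an illustrative sketch; the function names are not from either codebase.

```python
# Illustrative helpers for comparing peak GPU memory between the two scripts.
# Call reset_peak() once model and optimizer are set up, and report() after the
# first optimizer step on each rank.
import torch

def reset_peak(device: int) -> None:
    torch.cuda.reset_peak_memory_stats(device)

def report(device: int, tag: str) -> None:
    peak_gib = torch.cuda.max_memory_allocated(device) / 2**30
    print(f"[{tag}] rank {device} peak allocated: {peak_gib:.2f} GiB")
```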
Thanks. This is a very important detail that changes everything. But there are still many differences between the code that you shared and the reference.
I see you are specifying --precision bf16, for which the Lightning equivalent would be precision="bf16-true" (not 16 like you show). Also, the reference implementation truncates sequences to --context_length 512 by default, but this is nowhere specified in your code. Can you please share the full code you are running (replacing the data with dummy random data) so that we can reproduce what you are seeing? I also encourage you to compare with our LitGPT implementation I shared earlier.
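As an example, matching the reference's 512-token context on the data side could look roughly like the sketch below. This assumes a Hugging Face tokenizer pipeline, which may not match your actual data code.

```python
# Sketch of truncating/padding every sample to a fixed 512-token context so the
# Lightning run sees the same sequence lengths as the reference's --context_length 512.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openchat/openchat_3.5")

def encode(text: str):
    return tokenizer(
        text,
        truncation=True,
        max_length=512,
        padding="max_length",
        return_tensors="pt",
    )
```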
Another bug in the code is policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}, but you are using openchat/openchat_3.5 from Hugging Face, which doesn't have these layers, so they won't be wrapped. The reference implementation has it defined here: https://github.com/AnswerDotAI/fsdp_qlora/blob/3afe102048754c3fc499624511ab1a5ccd6ee45d/train.py#L441-L444. You should use the same.
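One way to avoid hard-coding the wrong classes is to derive the policy from the model you actually load, as in the sketch below. The model.model.layers path is an assumption about this checkpoint's architecture; adjust it to wherever your model keeps its transformer blocks.

```python
# Sketch: build the FSDP wrapping/checkpointing policy from the block class the
# loaded Hugging Face model really contains, instead of nn.TransformerEncoderLayer /
# nn.TransformerDecoderLayer (which never appear in this model).
from transformers import AutoModelForCausalLM
from lightning.pytorch.strategies import FSDPStrategy

model = AutoModelForCausalLM.from_pretrained("openchat/openchat_3.5")
block_cls = type(model.model.layers[0])  # assumed attribute path to the decoder blocks

strategy = FSDPStrategy(
    auto_wrap_policy={block_cls},
    activation_checkpointing_policy={block_cls},
)
```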
Thank you so much for getting back so quickly.
Okay, I created a sample script for you. To run answer_ai.py:
python answer_ai.py --world_size 8 --master_port 12356 --model_name openchat/openchat_3.5 --gradient_accumulation_steps 4 --batch_size 1 --precision bf16 --train_type full --use_gradient_checkpointing false --save_model true --use_activation_cpu_offload false --use_cpu_offload false --num_epochs 1
For PyTorch Lightning: you can run it directly, no need to pass any parameters.
There are two issues:
without precision="bf16-true" : Recomputed values for the following tensors have different metadata than during the forward pass. saved metadata: {'shape': torch.Size([1, 32, 228, 128]), 'dtype': torch.float32, 'device': device(type='cuda', index=3)} recomputed metadata: {'shape': torch.Size([1, 32, 456, 128]), 'dtype': torch.float32, 'device': device(type='cuda', index=3)}
and with precision="bf16-true":
saved metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cpu')} recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float32, 'device': device(type='cpu')}
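For the dummy random data that was requested, a fixed-length dataset like the sketch below keeps every batch the same shape, which also rules out the saved-vs-recomputed shape mismatch being data-dependent. The sizes and vocabulary are illustrative.

```python
# Illustrative fixed-length random-token dataset for reproducing the issue
# without sharing real data. Every sample has the same sequence length.
import torch
from torch.utils.data import Dataset

class RandomTokenDataset(Dataset):
    def __init__(self, num_samples: int = 256, seq_len: int = 512, vocab_size: int = 32000):
        self.data = torch.randint(0, vocab_size, (num_samples, seq_len))

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx):
        tokens = self.data[idx]
        # Reuse the inputs as labels; this only needs to exercise memory, not learn anything.
        return {"input_ids": tokens, "labels": tokens.clone()}
```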
@awaelchli I think I found the bug: convert_module is not defined for FSDPPrecision.
https://github.com/Lightning-AI/pytorch-lightning/blob/0c8a193d3c46f9ddba46a3ab1818e24ec41698af/src/lightning/pytorch/plugins/precision/fsdp.py#L34
So FSDPPrecision.convert_module ultimately falls back to convert_module of lightning.fabric.plugins.Precision, which simply does nothing.
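If that is indeed the cause, one possible workaround (sketched below as an assumption, not an official fix) is to make sure the parameters are already in bf16 before FSDP shards them, for example by overriding convert_module on the precision plugin or simply casting the model before calling fit.

```python
# Workaround sketch (an assumption, not an official fix): cast the module to bf16
# so FSDP does not shard fp32 parameters. Constructor arguments and registration
# details may differ between Lightning versions.
import torch
from torch.nn import Module
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision

class Bf16ConvertingFSDPPrecision(FSDPPrecision):
    def convert_module(self, module: Module) -> Module:
        # Move parameters and buffers to bf16 before the strategy wraps the model.
        return module.to(dtype=torch.bfloat16)

# plugin = Bf16ConvertingFSDPPrecision(precision="bf16-true")
# trainer = L.Trainer(strategy=FSDPStrategy(...), plugins=[plugin])
```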
Bug description
PyTorch Lightning is taking more memory than plain PyTorch FSDP. I'm able to train the gemma-2b model, but it takes 3 times more memory.
For openchat it goes out of memory. Please let me know if I'm missing anything.
I'm using 8 × A100 80 GB GPUs.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
For the PyTorch FSDP code I'm using https://github.com/AnswerDotAI/fsdp_qlora/blob/main/train.py with use_gradient_checkpointing: True, use_activation_cpu_offload: False, use_cpu_offload: False.
The context size is the same for both.
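To narrow down where the extra memory comes from, a quick diagnostic like the one below (illustrative, not part of the original report) logs what dtype the parameters end up in once the strategy has set up the model; fp32 shards and gradients under Lightning versus bf16 in the reference run would account for much of the gap.

```python
# Diagnostic sketch: report parameter dtypes right after training starts.
from collections import Counter
from lightning.pytorch.callbacks import Callback

class DtypeReport(Callback):
    def on_train_start(self, trainer, pl_module):
        counts = Counter(str(p.dtype) for p in pl_module.parameters())
        pl_module.print(f"parameter dtypes after setup: {dict(counts)}")

# trainer = L.Trainer(..., callbacks=[DtypeReport()])
```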
Error messages and logs
Environment
Current environment
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 2.2.1
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0): 2.2.1
#- Python version (e.g., 3.9): 3.9
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: 12.1.10
#- GPU models and configuration: A100 (8 * 80GB)
#- How you installed Lightning(`conda`, `pip`, source): pip
#- Running environment of LightningApp (e.g. local, cloud): local
```
More info
No response
cc @awaelchli @carmocca