Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

PyTorch Lightning FSDP takes more memory than PyTorch FSDP #19721

Status: Open. anandhperumal opened this issue 8 months ago.

anandhperumal commented 8 months ago

Bug description

PyTorch Lightning takes more memory than plain PyTorch FSDP. I'm able to train the gemma-2b model, but it uses about 3 times more memory.

With openchat, it runs out of memory. Please let me know if I'm missing anything.

I'm using 8 A100 GPUs with 80 GB each.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from custom_dataset import SupervisedDataset
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
import transformers

class LanguageModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = None

    def training_step(self, batch):
        outputs = self.model(**batch)
        loss = outputs[0]
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1)

    def configure_model(self):
        if self.model is not None:
            return
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            'openchat/openchat_3.5',
            torch_dtype=torch.bfloat16,
        )

L.seed_everything(42)
tokenizer = transformers.AutoTokenizer.from_pretrained('openchat/openchat_3.5')
tokenizer.pad_token = tokenizer.eos_token
train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path='./train.csv', mode='train')
eval_dataset = SupervisedDataset(tokenizer=tokenizer, data_path='./validation.csv', mode='validation')
test_dataset = SupervisedDataset(tokenizer=tokenizer, data_path='./test.csv', mode='test')

train_dataloader = DataLoader(train_dataset, batch_size=1)

model = LanguageModel()
sharding_strategy = {}

policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
sharding_strategy['sharding_strategy'] = "FULL_SHARD"
sharding_strategy['auto_wrap_policy'] = policy
sharding_strategy['state_dict_type'] = "full"
sharding_strategy['limit_all_gathers'] = True
sharding_strategy['cpu_offload'] = True

strategy = FSDPStrategy(
    **sharding_strategy
)

trainer = L.Trainer(accelerator="cuda",  strategy=strategy, precision=16, max_epochs=1)
trainer.fit(model, train_dataloader)
# trainer.print(torch.cuda.memory_summary())
trainer.save_checkpoint("path/to/checkpoint/file")

PyTorch FSDP code: https://github.com/AnswerDotAI/fsdp_qlora/blob/main/train.py
For PyTorch FSDP, I'm using use_gradient_checkpointing: True, use_activation_cpu_offload: False, use_cpu_offload: False.

The context size is the same for both.
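One way to make sure the context size really matches is to truncate every example explicitly before batching. The sketch below is a minimal illustration, not part of either script: the helper name truncate_example is hypothetical, it assumes each example is a dict of token-id lists (the actual fields of the custom SupervisedDataset are not shown), and the default of 512 mirrors the reference script's default --context_length mentioned later in this thread.

```python
from typing import Dict, List


def truncate_example(
    example: Dict[str, List[int]], context_length: int = 512
) -> Dict[str, List[int]]:
    """Cut every token sequence in one example to at most `context_length` tokens.

    Hypothetical helper: assumes all values are equal-length lists of ints
    (for example input_ids, attention_mask, labels).
    """
    return {key: values[:context_length] for key, values in example.items()}


# Usage with a made-up 1000-token example.
example = {
    "input_ids": list(range(1000)),
    "attention_mask": [1] * 1000,
    "labels": list(range(1000)),
}
short = truncate_example(example)
print({key: len(values) for key, values in short.items()})
```

Applying the same cut on both sides (for example inside the dataset's __getitem__, before converting to tensors) removes sequence length as a variable when comparing memory, since activation memory grows with it.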

Error messages and logs

CUDA out of memory

Environment

Current environment
- PyTorch Lightning Version: 2.2.1
- PyTorch Version: 2.2.1
- Python version: 3.9
- OS: Linux
- CUDA/cuDNN version: 12.1.10
- GPU models and configuration: A100 (8 * 80GB)
- How you installed Lightning: pip
- Running environment of LightningApp: local

More info

No response

cc @awaelchli @carmocca

awaelchli commented 8 months ago

Hi @anandhperumal

The reference implementation uses LoRA, but I don't see it configured anywhere in your code snippet. This makes a very big difference in memory consumption. Furthermore, you didn't enable activation checkpointing in FSDP in the code above, but you reported doing so in the reference implementation, which will also have a big impact. Please check again; it's important to compare equivalent settings.

If you'd like to try out LoRA with Lightning, we have an implementation here (and docs).

anandhperumal commented 8 months ago

@awaelchli the reference code also has a "full" option, which trains the entire model. I'm using the full option.

python -u src/train.py --world_size 8 --master_port 12356 --model_name openchat/openchat-3.5 --gradient_accumulation_steps 4 --batch_size 1 --precision bf16 --train_type full --use_gradient_checkpointing false --save_model true  --use_activation_cpu_offload false --use_cpu_offload false --num_epochs 3 --lr_scheduler linear
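For a like-for-like comparison, the training flags in this command can be translated into Lightning Trainer arguments. The sketch below rests on assumptions: the helper name is hypothetical, only precision "bf16" is mapped (to "bf16-true", following the maintainer's reply further down), and flags such as --lr_scheduler linear or the checkpointing and offload switches have no direct Trainer argument, so they would have to be handled in the LightningModule or the FSDPStrategy instead.

```python
def trainer_kwargs_from_reference(
    world_size: int,
    gradient_accumulation_steps: int,
    precision: str,
    num_epochs: int,
) -> dict:
    """Hypothetical mapping from the reference script's flags to L.Trainer kwargs."""
    # Only the bf16 setting used in this thread is mapped; anything else is
    # rejected rather than guessed.
    precision_map = {"bf16": "bf16-true"}
    if precision not in precision_map:
        raise ValueError(f"no known Lightning equivalent for precision={precision!r}")
    return {
        "accelerator": "cuda",
        "devices": world_size,  # one process per GPU, like --world_size
        "precision": precision_map[precision],
        "accumulate_grad_batches": gradient_accumulation_steps,
        "max_epochs": num_epochs,
    }


# Values from the command above; --batch_size 1 corresponds to DataLoader(batch_size=1),
# which Lightning applies per GPU process.
kwargs = trainer_kwargs_from_reference(8, 4, "bf16", 3)
print(kwargs)
# With Lightning installed: trainer = L.Trainer(strategy=strategy, **kwargs)
```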

Also, I added the activation checkpointing policy to the Lightning strategy, but it made no difference. In fact, I even enabled cpu_offload for PyTorch Lightning but not for plain PyTorch.

sharding_strategy['activation_checkpointing_policy'] = policy

For reference, the image below shows openchat's memory consumption with the PyTorch FSDP code, whereas the Lightning version doesn't complete even a single training step. Please let me know if I'm missing anything.

[Image: openchat memory consumption with the PyTorch FSDP code]
awaelchli commented 8 months ago

Thanks, this is a very important detail that changes everything. But there are still many differences between the code you shared and the reference.

I see you are specifying --precision bf16, for which the Lightning equivalent would be precision="bf16-true" (not 16 as in your code). Also, the reference implementation truncates sequences to --context_length 512 by default, but this isn't specified anywhere in your code. Can you please share the full code you are running (replacing the data with dummy random data) so that we can reproduce what you are seeing? I also encourage you to compare with our LitGPT implementation that I shared earlier.
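To see why the precision setting matters so much, here is a rough back-of-the-envelope estimate of the sharded parameter, gradient, and Adam-state memory per GPU under FULL_SHARD on 8 GPUs. It is a sketch under stated assumptions, not a measurement: roughly 7.2 billion parameters for a 7B model such as openchat_3.5 (an approximation), 16 bytes per parameter when fp32 weights, gradients, and Adam states are kept (as with fp32 master weights in a mixed-precision setup), and 8 bytes per parameter when all of them stay in bf16. Activations, all-gather buffers, and allocator overhead are ignored.

```python
GIB = 1024 ** 3


def sharded_state_gib(num_params: float, bytes_per_param: int, num_gpus: int) -> float:
    """Per-GPU memory for params + grads + Adam states when fully sharded.

    Ignores activations, temporary all-gather buffers, and fragmentation.
    """
    return num_params * bytes_per_param / num_gpus / GIB


NUM_PARAMS = 7.2e9  # assumption: approximate size of a 7B model
NUM_GPUS = 8

# Assumed bytes per parameter:
#   fp32 states: 4 (weights) + 4 (grads) + 4 + 4 (Adam exp_avg, exp_avg_sq) = 16
#   all bf16:    2 (weights) + 2 (grads) + 2 + 2 (Adam states)             = 8
fp32_states = sharded_state_gib(NUM_PARAMS, 16, NUM_GPUS)
bf16_states = sharded_state_gib(NUM_PARAMS, 8, NUM_GPUS)
print(f"fp32 weights/grads/optimizer: ~{fp32_states:.1f} GiB per GPU")
print(f"bf16 weights/grads/optimizer: ~{bf16_states:.1f} GiB per GPU")
```

The factor of two here comes purely from the bytes-per-parameter assumption; activation memory, which grows with sequence length, is the other large term and is what the context_length difference affects.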

awaelchli commented 8 months ago

Another bug in the code is policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}: you are using openchat/openchat_3.5 from Hugging Face, which doesn't contain these layers, so nothing will be wrapped. The reference implementation defines its policy here: https://github.com/AnswerDotAI/fsdp_qlora/blob/3afe102048754c3fc499624511ab1a5ccd6ee45d/train.py#L441-L444. You should use the same layers.
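A minimal sketch of a strategy whose wrap policy targets the model's own decoder layers. It assumes openchat/openchat_3.5 loads with the Mistral architecture in transformers, so its blocks are MistralDecoderLayer; printing the loaded model will confirm the actual class name. The remaining arguments mirror the original snippet, with activation checkpointing applied to the same layer class.

```python
from lightning.pytorch.strategies import FSDPStrategy
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

# Assumption: openchat_3.5 is a Mistral-architecture model, so each transformer
# block is a MistralDecoderLayer. Each wrapped block becomes its own FSDP unit,
# so parameters are all-gathered block by block instead of for the whole model.
layer_policy = {MistralDecoderLayer}

strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    auto_wrap_policy=layer_policy,
    activation_checkpointing_policy=layer_policy,  # recompute block activations in backward
    state_dict_type="full",
    limit_all_gathers=True,
)
```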

anandhperumal commented 8 months ago

Thank you so much for getting back so quickly.

Okay, I created a sample script for you. To run answer_ai.py:

python answer_ai.py --world_size 8 --master_port 12356 --model_name openchat/openchat_3.5 --gradient_accumulation_steps 4 --batch_size 1 --precision bf16 --train_type full --use_gradient_checkpointing false --save_model true --use_activation_cpu_offload false --use_cpu_offload false --num_epochs 1

[Image: results of the answer_ai.py run]

For PyTorch Lightning, you can run the script directly without passing any parameters.

There are two issues. Without precision="bf16-true":

Recomputed values for the following tensors have different metadata than during the forward pass. saved metadata: {'shape': torch.Size([1, 32, 228, 128]), 'dtype': torch.float32, 'device': device(type='cuda', index=3)} recomputed metadata: {'shape': torch.Size([1, 32, 456, 128]), 'dtype': torch.float32, 'device': device(type='cuda', index=3)}

and with precision="bf16-true":

saved metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cpu')} recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float32, 'device': device(type='cpu')}

script.zip

function2-llx commented 6 months ago

@awaelchli I think I found the bug. I can't find convert_module defined for FSDPPrecision:
https://github.com/Lightning-AI/pytorch-lightning/blob/0c8a193d3c46f9ddba46a3ab1818e24ec41698af/src/lightning/pytorch/plugins/precision/fsdp.py#L34
As a result, FSDPPrecision.convert_module ultimately falls back to the convert_module of lightning.fabric.plugins.Precision, which simply does nothing:

https://github.com/Lightning-AI/pytorch-lightning/blob/0c8a193d3c46f9ddba46a3ab1818e24ec41698af/src/lightning/fabric/plugins/precision/precision.py#L48-L54
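To make the consequence concrete, the toy sketch below contrasts a convert_module that returns the module unchanged (the fallback behavior described above) with one that casts floating-point parameters to bfloat16, which is what a "bf16-true" setting would be expected to do. Both function names are hypothetical and this is not Lightning's code; it only shows how much parameter memory a no-op leaves in place for a single float32 layer.

```python
import torch
from torch import nn


def noop_convert_module(module: nn.Module) -> nn.Module:
    # Mirrors the fallback described above: the module is returned untouched.
    return module


def bf16_convert_module(module: nn.Module) -> nn.Module:
    # Hypothetical "true" bf16 conversion: Module.to(dtype) casts floating-point
    # parameters and buffers, leaving integer buffers alone.
    return module.to(dtype=torch.bfloat16)


def param_bytes(module: nn.Module) -> int:
    # Total bytes held by the module's parameters.
    return sum(p.numel() * p.element_size() for p in module.parameters())


kept = noop_convert_module(nn.Linear(1024, 1024))  # float32 under the default dtype
cast = bf16_convert_module(nn.Linear(1024, 1024))
print(kept.weight.dtype, param_bytes(kept))
print(cast.weight.dtype, param_bytes(cast))
```

In the reproduction script the model is already loaded with torch_dtype=torch.bfloat16, so whether a missing conversion matters there depends on what else the precision plugin and strategy do; the sketch only isolates the conversion step.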