Investigate FSDP + CPU Offload performance in Trainer #18336

Open awaelchli opened 1 year ago

awaelchli commented 1 year ago

Bug description

When writing the new FSDP guide for Trainer in #18326, I got suspiciously slow iteration speed when enabling CPU offload (see

Iterations per second:
- Fabric FSDP + Offload: 0.3
- Trainer FSDP + Offload: 0.02
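To put the two figures in perspective, the gap can be worked out directly from the numbers above. This is only arithmetic on the reported values; the variable names are illustrative and nothing here is measured anew.

```python
# Throughput figures quoted in the report (iterations per second).
fabric_its_per_sec = 0.3
trainer_its_per_sec = 0.02

# Seconds spent on a single iteration is the reciprocal of the throughput.
fabric_sec_per_it = 1 / fabric_its_per_sec
trainer_sec_per_it = 1 / trainer_its_per_sec

# How many times slower the Trainer run is compared to the Fabric run.
slowdown = fabric_its_per_sec / trainer_its_per_sec

print(f"Fabric:  {fabric_sec_per_it:.1f} s/it")
print(f"Trainer: {trainer_sec_per_it:.1f} s/it")
print(f"Trainer is about {slowdown:.0f}x slower")
```

In other words, the Trainer run takes roughly 50 seconds per iteration against roughly 3.3 seconds for Fabric, a factor of about 15.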

What version are you seeing the problem on?


How to reproduce the bug

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.demos import Transformer, WikiText2

class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        self.model = self.model or Transformer(  # 1B parameters
            vocab_size=self.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64
        )

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1)


# Data
dataset = WikiText2()
train_dataloader = DataLoader(dataset)

# Model
model = LanguageModel(vocab_size=dataset.vocab_size)

policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
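strategy = FSDPStrategy(
    # The original call was cut off here; these arguments are assumed from the
    # `policy` set above and the CPU-offload scenario this issue describes.
    auto_wrap_policy=policy,
    cpu_offload=True,
)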

# Trainer
trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)
trainer.fit(model, train_dataloader)
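The report does not say how the iterations-per-second figures were obtained (they may simply have been read off the progress bar). To compare the Trainer and Fabric runs on equal footing, one option is to time the training step by hand. The following is a minimal sketch, not part of the original script: `iterations_per_second`, `step_fn`, and `sync` are hypothetical names introduced here for illustration.

```python
import time


def iterations_per_second(step_fn, num_iters=20, warmup=3, sync=None):
    """Return the average number of step_fn() calls completed per second.

    step_fn: zero-argument callable that runs one training iteration.
    warmup:  iterations run first and excluded from the timing.
    sync:    optional zero-argument callable invoked before each clock read,
             e.g. torch.cuda.synchronize when the work runs on a GPU.
    """
    # Warm-up iterations absorb one-time costs (allocation, compilation).
    for _ in range(warmup):
        step_fn()

    if sync is not None:
        sync()
    start = time.perf_counter()

    for _ in range(num_iters):
        step_fn()

    # Wait for outstanding asynchronous work before stopping the clock.
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start

    return num_iters / elapsed


if __name__ == "__main__":
    # Stand-in workload: each "iteration" sleeps for 10 ms.
    rate = iterations_per_second(lambda: time.sleep(0.01), num_iters=10)
    print(f"{rate:.1f} it/s")
```

With CUDA, passing `sync=torch.cuda.synchronize` matters because kernel launches are asynchronous; without it the clock may stop before the GPU has finished the work.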

Error messages and logs

No errors.


More info

cc @borda @awaelchli @carmocca

carmocca commented 1 year ago

Do you observe the same results with Fabric?

awaelchli commented 1 year ago

I opened the issue precisely because the difference to Fabric is so noticeable. The numbers are in the description above and in the docs pages.