Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Handle gradient accumulations at the end of epoch differently #19987

Open · jakub-dyno opened this issue 3 months ago

jakub-dyno commented 3 months ago

Bug description

When accumulate_grad_batches > 1, the dataloader may run out of batches at the end of an epoch before the required number of accumulation steps has been reached. The Lightning docs do not say what happens in that case. Lightning could:

  1. not update the gradients
  2. update gradients correctly but with an effectively smaller batch size
  3. update gradients incorrectly because the gradients are scaled by accumulate_grad_batches instead of the actual number of accumulations
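Stated formally (the symbols are introduced here only for illustration): let N be accumulate_grad_batches and suppose the final window of the epoch holds only k < N batches with losses ℓ_1, …, ℓ_k. Options 2 and 3 then differ only by a factor of k/N, while option 1 would discard those k gradients entirely:

```latex
g_{\text{option 2}} = \frac{1}{k}\sum_{i=1}^{k} \nabla \ell_i,
\qquad
g_{\text{option 3}} = \frac{1}{N}\sum_{i=1}^{k} \nabla \ell_i = \frac{k}{N}\, g_{\text{option 2}}
```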

My experiments suggest it's option 3, but I'm happy to be wrong.


What version are you seeing the problem on?

v2.0, v2.2

How to reproduce the bug

import torch
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

# Generate some dummy data
X = torch.randn(1000, 28*28)  # 1000 samples of 28*28 features
y = torch.randint(0, 10, (1000,))  # 1000 labels for 10 classes

# Create a TensorDataset
dataset = TensorDataset(X, y)
# Create DataLoaders
train_loader = DataLoader(dataset, batch_size=1, shuffle=True)

class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 10)
        self.loss = 0
        self.steps = 0

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.layer_1(x))
        return torch.log_softmax(self.layer_2(x), dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.nll_loss(y_hat, y)
        self.loss += loss.detach()
        self.steps += 1
        self.log('train_loss_batch', loss.detach(), on_step=True, on_epoch=False, logger=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

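    def on_before_optimizer_step(self, optimizer: Optimizer):
        # With gradient accumulation, this runs once per optimizer step,
        # after the gradients of the whole window have been accumulated
        self.log('train_loss', self.loss / self.steps, prog_bar=True, logger=True)
        self.log('accumulations', self.steps, prog_bar=False, logger=True)
        self.loss = 0
        self.steps = 0

        # Global L2 norm over all parameter gradients
        # (sqrt of the sum of squared per-parameter norms, not the plain sum of norms)
        grads = [param.grad.detach().norm(2) for param in self.parameters() if param.grad is not None]
        grad_norm = torch.stack(grads).norm(2).item() if grads else 0.0
        self.log('grad norm', grad_norm, prog_bar=False, logger=True)

        return super().on_before_optimizer_step(optimizer)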

model = SimpleModel()
trainer = pl.Trainer(max_epochs=5, accumulate_grad_batches=32, log_every_n_steps=1)
trainer.fit(model, train_loader)
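With the repro's settings (1,000 samples, batch_size=1, accumulate_grad_batches=32), the epoch does not divide evenly. The small pure-Python sketch below (the helper name `accumulation_windows` is hypothetical, not a Lightning API) computes the window sizes one would expect the `accumulations` metric to report in each epoch, and the shrink factor option 3 would apply to the last window:

```python
from typing import List


def accumulation_windows(num_batches: int, accumulate_grad_batches: int) -> List[int]:
    """Return the number of batches in each accumulation window of one epoch."""
    full_windows, remainder = divmod(num_batches, accumulate_grad_batches)
    windows = [accumulate_grad_batches] * full_windows
    if remainder:
        # The dataloader runs out before a full window is collected
        windows.append(remainder)
    return windows


# Repro settings: 1000 samples, batch_size=1, accumulate_grad_batches=32
windows = accumulation_windows(1000, 32)
print(len(windows), windows[-1])  # 32 8
# Under option 3, the last update uses gradients shrunk by this factor
print(windows[-1] / 32)  # 0.25
```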

cc @borda

awaelchli commented 3 months ago

@jakub-dyno If the number of batches in an epoch is not evenly divisible by accumulate_grad_batches, the optimizer still steps on the final, partial window at the end of the epoch. The partially accumulated gradients can't be kept in memory and continued into the next epoch.

The loss is divided by accumulate_grad_batches regardless of whether the accumulation window is full or partial: https://github.com/Lightning-AI/pytorch-lightning/blob/e330da5870fae34339170b942095a2600fa7a95e/src/lightning/pytorch/loops/optimization/automatic.py#L327

https://github.com/Lightning-AI/pytorch-lightning/blob/e330da5870fae34339170b942095a2600fa7a95e/src/lightning/pytorch/loops/optimization/automatic.py#L81
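Because gradients from repeated backward() calls add up in `.grad`, dividing every loss by a fixed accumulate_grad_batches means a final window of k batches yields k / accumulate_grad_batches times the mean gradient, which matches option 3. The sketch below checks this arithmetic with a hypothetical 1-D least-squares model in plain Python standing in for PyTorch autograd; the data and helper names are illustrative only:

```python
def sample_grad(w: float, x: float, y: float) -> float:
    """Gradient of 0.5 * (w * x - y) ** 2 with respect to w."""
    return (w * x - y) * x


def accumulated_grad(w: float, window, divisor: float) -> float:
    """Mimic summing the gradients of loss / divisor over every batch in the window."""
    return sum(sample_grad(w, x, y) / divisor for x, y in window)


accumulate_grad_batches = 32
# A partial final window of 8 single-sample batches (hypothetical data)
window = [(float(i), 2.0 * i + 1.0) for i in range(1, 9)]
w = 0.5

g_option3 = accumulated_grad(w, window, accumulate_grad_batches)  # divide by 32
g_mean = accumulated_grad(w, window, len(window))                 # divide by 8
print(g_option3 / g_mean)  # 0.25
```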

If you'd like to change this, a PR would be welcome, but it could be a bit involved.
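One possible direction, sketched under assumptions: keep the fixed 1/N loss scaling and, at optimizer-step time, multiply the accumulated gradient by N/k, where k is the number of batches actually accumulated. Because gradients add linearly, this turns option 3 into option 2 without knowing in advance that a window will be short. The `Accumulator` class and `run_epoch` loop below are a hypothetical plain-Python model of that bookkeeping, not Lightning code:

```python
class Accumulator:
    """Hypothetical helper: accumulates gradients of loss / N (as Lightning does)
    and rescales them at step time by the number of batches actually seen."""

    def __init__(self, accumulate_grad_batches: int):
        self.n = accumulate_grad_batches
        self.grad = 0.0
        self.count = 0

    def backward(self, sample_grad: float) -> None:
        # Fixed 1/N scaling, independent of how full the window ends up
        self.grad += sample_grad / self.n
        self.count += 1

    def step_grad(self) -> float:
        # Undo the fixed 1/N scaling and apply 1/count instead, then reset
        g = self.grad * self.n / self.count
        self.grad, self.count = 0.0, 0
        return g


def run_epoch(sample_grads, accumulate_grad_batches):
    """Return the gradient used at each optimizer step of one epoch."""
    acc = Accumulator(accumulate_grad_batches)
    steps = []
    for i, g in enumerate(sample_grads):
        acc.backward(g)
        is_last = i == len(sample_grads) - 1
        # Step on a full window, or on the partial window at the end of the epoch
        if acc.count == accumulate_grad_batches or is_last:
            steps.append(acc.step_grad())
    return steps


# 10 batches with accumulation of 4: windows of 4, 4 and 2 batches
print(run_epoch([float(i) for i in range(1, 11)], 4))  # [2.5, 6.5, 9.5]
```

Each step now receives the mean gradient of its window, including the short last one. In the repro, the analogous (untested) change would be to multiply each `param.grad` by `self.trainer.accumulate_grad_batches / self.steps` inside `on_before_optimizer_step`, before `self.steps` is reset.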