Lightning-AI / pytorch-lightning

Steps not incremented correctly with accumulate gradients #2242

Closed · mRcSchwering closed this issue 4 years ago

mRcSchwering commented 4 years ago

πŸ› Bug

global_step and current_epoch no longer match up after more than one epoch when accumulate_grad_batches is set greater than 1. I think optimizer_step (and on_before_zero_grad) is not called at the end of each epoch in that case.

To Reproduce

  1. Create a pl.LightningModule that logs current_epoch and global_step in every training_step_end.
  2. Have a dataset with 20k samples, set the dataloader batch size to 1171, and set accumulate_grad_batches=7 in the trainer
  3. Train for 100 steps
  4. Look at the logs and the overall number of steps and epochs (a sanity check of the expected counts follows below)
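
For reference, the expected counts can be worked out as follows. This is only a sketch; it assumes the dataset holds exactly 20,000 samples and that the trailing partial accumulation window is flushed at the end of the epoch:

import math

n_samples = 20_000
batch_size = 1171
accumulate = 7

batches_per_epoch = math.ceil(n_samples / batch_size)                # 18 batches per epoch
last_batch_size = n_samples - (batches_per_epoch - 1) * batch_size   # 93 samples in the last batch
steps_per_epoch = math.ceil(batches_per_epoch / accumulate)          # 3 optimizer steps per epoch

print(batches_per_epoch, last_batch_size, steps_per_epoch)           # 18 93 3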

Expected behavior

global_step keeps increasing with every optimizer step, including the last (partial) accumulation window of each epoch, so it stays in sync with current_epoch, and on_before_zero_grad is called for that window as well.

Actual behavior

At the end of each epoch current_epoch increases but global_step does not, and on_before_zero_grad is not called for the last (partial) accumulation window (see the sample output below).

Code sample

Below is basically what I have. I am adjusting the learning rate with every global step. Both the learning rate adjustment and each training_step_end call are printed.

import torch
from torch import optim
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# NOTE: model_fact(), MyDataset() and self.criterion are project-specific and not shown here.
class MyModule(pl.LightningModule):

    def __init__(self, hparams: dict):
        super().__init__()
        self.hparams = hparams
        self.model = model_fact()

    def training_step_end(self, outputs: dict):
        print(f'Epoch: {self.current_epoch} Step: {self.global_step} Batch size: {len(outputs["logits"])}')

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer):
        current_lr = [d['lr'] for d in optimizer.param_groups][0]
        print(f'Step: {self.global_step} LR: {current_lr:.4e}')

    def train_dataloader(self):
        return DataLoader(MyDataset(), batch_size=1171, shuffle=False)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx: int) -> dict:
        inputs, targets = batch
        logits = self.forward(*inputs)
        loss = self.criterion(logits, targets)
        return {'loss': loss, 'logits': logits}

    def configure_optimizers(self):
        return optim.Adam(self.model.parameters())

    def optimizer_step(self, epoch: int, batch_idx: int, optimizer: optim.Optimizer,
                       optimizer_idx: int, second_order_closure: callable = None):
        # modify learning rate...
        optimizer.step()
        self.on_before_zero_grad(optimizer)
        optimizer.zero_grad()

trainer = pl.Trainer(
        max_steps=100,
        max_epochs=int(1e6),
        gpus=N_GPUS if N_GPUS > 0 else None,
        distributed_backend=None if N_GPUS < 2 else 'dp',
        num_sanity_val_steps=0,
        progress_bar_refresh_rate=0,
        accumulate_grad_batches=7,
        early_stop_callback=False)

Below is some sample output. You can see the end of the epoch where the last 93 samples are processed. Then, current_epoch increases, but global_step does not. Additionally, the learning-rate print is missing, so on_before_zero_grad was not called. (A plain-PyTorch sketch of the expected accumulation behavior follows after the output.)

Step: 53 LR: 3.3861e-04
Epoch: 26 Step: 54 Batch size: 1171
Epoch: 26 Step: 54 Batch size: 1171
Epoch: 26 Step: 54 Batch size: 1171
Epoch: 26 Step: 54 Batch size: 93
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Step: 54 LR: 3.3267e-04
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Step: 55 LR: 3.2673e-04
Epoch: 27 Step: 56 Batch size: 1171
Epoch: 27 Step: 56 Batch size: 1171
Epoch: 27 Step: 56 Batch size: 1171
Epoch: 27 Step: 56 Batch size: 93
Epoch: 28 Step: 56 Batch size: 1171
Epoch: 28 Step: 56 Batch size: 1171
Epoch: 28 Step: 56 Batch size: 1171
Epoch: 28 Step: 56 Batch size: 1171
Epoch: 28 Step: 56 Batch size: 1171
Epoch: 28 Step: 56 Batch size: 1171
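
For comparison, this is the accumulation behavior the issue expects, written as a plain PyTorch loop rather than Lightning internals (a minimal, self-contained sketch with a toy model and dataset): the optimizer steps after every accumulate batches and once more for the leftover partial window at the end of the epoch, and the step counter advances each time.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Toy setup just to make the sketch runnable (8 random features, 2 classes).
dataset = TensorDataset(torch.randn(20_000, 8), torch.randint(0, 2, (20_000,)))
dataloader = DataLoader(dataset, batch_size=1171, shuffle=False)
model = nn.Linear(8, 2)
optimizer = optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

accumulate = 7
global_step = 0

for epoch in range(2):
    optimizer.zero_grad()
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        loss = criterion(model(inputs), targets)
        (loss / accumulate).backward()   # scale the loss so gradients are averaged over the window

        last_batch = batch_idx == len(dataloader) - 1
        if (batch_idx + 1) % accumulate == 0 or last_batch:
            optimizer.step()             # also step on the trailing partial window at epoch end
            optimizer.zero_grad()
            global_step += 1             # so global_step keeps advancing across epoch boundaries

    print(f'epoch {epoch} done at global_step {global_step}')   # 3 steps per epoch: 7 + 7 + 4 batches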

Environment

williamFalcon commented 4 years ago

@mRcSchwering can you try 0.8.0?

mRcSchwering commented 4 years ago

Just tried on 0.8.1 (hope that's as good). The issue remains, e.g.:

Epoch: 43 Step: 85 Batch size: 1171
Epoch: 43 Step: 85 Batch size: 1171
Epoch: 43 Step: 85 Batch size: 1171
Step: 85 LR: 1.4851e-04
Epoch: 43 Step: 86 Batch size: 1171
Epoch: 43 Step: 86 Batch size: 1171
Epoch: 43 Step: 86 Batch size: 1171
Epoch: 43 Step: 86 Batch size: 93
Epoch: 44 Step: 86 Batch size: 1171
Epoch: 44 Step: 86 Batch size: 1171
Epoch: 44 Step: 86 Batch size: 1171
Epoch: 44 Step: 86 Batch size: 1171
Epoch: 44 Step: 86 Batch size: 1171
Borda commented 4 years ago

@mRcSchwering mind checking it with our latest master? 🐰

ydcjeff commented 4 years ago

I guess this is solved by https://github.com/PyTorchLightning/pytorch-lightning/pull/2853, and I could reproduce the expected behavior. Can you please confirm, @mRcSchwering? Here is the script, adjusted to be compatible with current master:

import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

pl.seed_everything(666)

class MyModule(pl.LightningModule):

    def __init__(self, hparams: dict):
        super().__init__()
        self.hparams = hparams
        self.model = nn.Linear(28*28, 10)

    def training_step_end(self, outputs: dict):
        print(f'Epoch: {self.current_epoch} Step: {self.global_step} Batch size: {len(outputs["logits"])}')
        return outputs

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer):
        current_lr = [d['lr'] for d in optimizer.param_groups][0]
        print(f'Step: {self.global_step} LR: {current_lr:.4e}')

    def train_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=1171, shuffle=False)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx: int) -> dict:
        inputs, targets = batch
        logits = self.forward(inputs.view(inputs.size(0), -1))
        loss = F.cross_entropy(logits, targets)
        return {'loss': loss, 'logits': logits}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=3e-4)

    def optimizer_step(self, epoch, batch_idx, optimizer, opt_idx, lambda_closure, using_native_amp, using_lbfgs):
        # modify learning rate...
        optimizer.step()
        self.on_before_zero_grad(optimizer)
        optimizer.zero_grad()

trainer = pl.Trainer(
        max_steps=100,
        max_epochs=int(1e6),
        gpus=-1,
        num_sanity_val_steps=0,
        progress_bar_refresh_rate=0,
        accumulate_grad_batches=7,
        early_stop_callback=False)
model = MyModule({})
trainer.fit(model)

Output: the LR is printed every 7 accumulated batches and also on the last batch of the epoch, and current_epoch and global_step are incremented as expected (the numbers are checked against MNIST below).

Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Step: 8 LR: 3.0000e-04
Step: 8 LR: 3.0000e-04
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Step: 9 LR: 3.0000e-04
Step: 9 LR: 3.0000e-04
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Step: 10 LR: 3.0000e-04
Step: 10 LR: 3.0000e-04
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Step: 11 LR: 3.0000e-04
Step: 11 LR: 3.0000e-04
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Step: 12 LR: 3.0000e-04
Step: 12 LR: 3.0000e-04
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Step: 13 LR: 3.0000e-04
Step: 13 LR: 3.0000e-04
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Step: 14 LR: 3.0000e-04
Step: 14 LR: 3.0000e-04
Epoch: 1 Step: 15 Batch size: 1171
Epoch: 1 Step: 15 Batch size: 1171
Epoch: 1 Step: 15 Batch size: 279
Step: 15 LR: 3.0000e-04
Step: 15 LR: 3.0000e-04
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Step: 16 LR: 3.0000e-04
Step: 16 LR: 3.0000e-04
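
These numbers line up with the expected counts, assuming the standard MNIST training split of 60,000 samples:

import math

n_samples = 60_000                                                   # standard MNIST training split
batch_size = 1171
accumulate = 7

batches_per_epoch = math.ceil(n_samples / batch_size)                # 52 batches per epoch
last_batch_size = n_samples - (batches_per_epoch - 1) * batch_size   # 279, as printed above
steps_per_epoch = math.ceil(batches_per_epoch / accumulate)          # 8, matching steps 8..15 in epoch 1
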
mRcSchwering commented 4 years ago

Cool, thx. And I learned something new (seed_everything)

ydcjeff commented 4 years ago

So I guess this can be closed.