Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

OneCycleLR scheduler state not restored from checkpoint file #15462

Closed: vhewes closed this issue 1 year ago

vhewes commented 2 years ago

Bug description

I recently adapted a network architecture to a LightningModule, and I find that when resuming an in-progress training from a checkpoint file, the state of the OneCycleLR scheduler is not properly restored. I've tested with version 1.8.0 and confirmed that the issue persists.

The example pasted below runs a full end-to-end training of a toy model and dataset, then trains the same model halfway to completion, and finally loads the checkpoint file from the halfway-trained model and trains it the rest of the way. The example plots the learning rate in both cases, demonstrating that in the latter case the learning rate scheduler's internal state is not restored successfully when loading from the checkpoint file.

How to reproduce the bug

from typing import List
import torch
import pytorch_lightning as pl
import pandas as pd
import matplotlib.pyplot as plt

class ToyDataset(torch.utils.data.Dataset):
    def __init__(self):
        super(ToyDataset, self).__init__()

    def __len__(self) -> int:
        return 1000

    def __getitem__(self, idx: int) -> List[torch.Tensor]:
        return [ torch.rand(5), torch.randint(high=2, size=[1]) ]

class ToyModel(pl.LightningModule):
    def __init__(self):
        super(ToyModel, self).__init__()

        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features=5,
                            out_features=1),
            torch.nn.Sigmoid())

        self.hparams.learning_rate = 0.1

        self.loss_func = torch.nn.BCELoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def training_step(self,
                      batch: torch.Tensor,
                      batch_idx: int) -> torch.Tensor:
        x, y = batch
        x = self(x)
        loss = self.loss_func(x, y.float())
        self.log('learning rate', self.optimizers().state_dict()['param_groups'][0]['lr'])
        return loss

    def configure_optimizers(self) -> tuple:
        print(self.trainer.estimated_stepping_batches)
        optimizer = torch.optim.SGD(self.parameters(),
                                    lr=self.hparams.learning_rate,
                                    momentum=0.9)
        onecycle = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.hparams.learning_rate,
            final_div_factor=1e6,
            total_steps=100)
        return [optimizer], {'scheduler': onecycle, 'interval': 'step'}

def main():

    data = torch.utils.data.DataLoader(ToyDataset(), batch_size=200)
    model = ToyModel()

    # train all the way
    logger = pl.loggers.CSVLogger(save_dir='.',
                                  name='toy',
                                  version='full')
    trainer = pl.Trainer(accelerator='cpu',
                         max_epochs=20,
                         log_every_n_steps=1,
                         logger=logger)
    trainer.fit(model, train_dataloaders=data)

    # train halfway
    logger = pl.loggers.CSVLogger(save_dir='.',
                                  name='toy',
                                  version='halfway')
    trainer = pl.Trainer(accelerator='cpu',
                         max_epochs=10,
                         log_every_n_steps=1,
                         logger=logger)
    trainer.fit(model, train_dataloaders=data)

    # load checkpoint and train the rest of the way
    logger = pl.loggers.CSVLogger(save_dir='.',
                                  name='toy',
                                  version='resumed')
    trainer = pl.Trainer(accelerator='cpu',
                         max_epochs=20,
                         log_every_n_steps=1,
                         logger=logger)
    trainer.fit(model,
                train_dataloaders=data,
                ckpt_path='toy/halfway/checkpoints/epoch=9-step=50.ckpt')

    # compare the learning rates of the full and resumed training instances
    full = pd.read_csv('toy/full/metrics.csv')
    halfway = pd.read_csv('toy/halfway/metrics.csv')
    resumed = pd.read_csv('toy/resumed/metrics.csv')
    combined = pd.concat([halfway, resumed])
    ax = full.plot(x='step', y='learning rate', ylabel='learning rate', label='full')
    combined.plot(x='step', y='learning rate', label='resumed', ax=ax)
    plt.show()

if __name__ == '__main__':

    main()
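
As a side note, the scheduler state that Lightning writes into the checkpoint can be inspected directly. The snippet below is a minimal sketch (it assumes the halfway checkpoint path produced by the script above and Lightning's standard 'lr_schedulers' checkpoint entry); it prints the OneCycleLR counters that should be carried over when resuming.

import torch

# load the halfway checkpoint written by the script above (path taken from the repro)
ckpt = torch.load('toy/halfway/checkpoints/epoch=9-step=50.ckpt', map_location='cpu')

# Lightning stores scheduler state under 'lr_schedulers'; for OneCycleLR the position
# in the cycle is tracked by 'last_epoch' / '_step_count', which is what should be restored
for sched_state in ckpt.get('lr_schedulers', []):
    print({k: v for k, v in sched_state.items() if k in ('last_epoch', '_step_count')})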

Error messages and logs

No response

Environment

More info

The environment provided above is my local machine, where I constructed the toy example, but I also observe the same issue on an Nvidia GPU cluster and in HPC environments, so it is not localised to a specific architecture.

cc @rohitgr7

rohitgr7 commented 2 years ago

Okay, I know the issue. Just to unblock you, can you use:

self.log('learning rate', self.trainer.optimizers[0].state_dict()['param_groups'][0]['lr'])
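
For context, self.optimizers() returns Lightning's LightningOptimizer wrapper, while self.trainer.optimizers[0] is the underlying torch.optim.SGD instance; the suggestion only changes where the learning rate is read from for logging. A minimal sketch of the adjusted training_step, with everything else unchanged from the reproduction above:

    def training_step(self,
                      batch: torch.Tensor,
                      batch_idx: int) -> torch.Tensor:
        x, y = batch
        x = self(x)
        loss = self.loss_func(x, y.float())
        # read the learning rate from the raw optimizer instead of the LightningOptimizer wrapper
        self.log('learning rate', self.trainer.optimizers[0].state_dict()['param_groups'][0]['lr'])
        return loss
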
awaelchli commented 1 year ago

This was fixed in https://github.com/Lightning-AI/lightning/pull/18280. See my full reply on another issue: https://github.com/Lightning-AI/lightning/issues/17296#issuecomment-1726715614
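
For anyone wanting to verify the fix numerically with the reproduction script above rather than by eye, a minimal sketch (assuming the metrics.csv files produced by the CSVLogger runs and a Lightning version that includes the fix) is to compare the logged learning rates of the full run against the halfway-plus-resumed run step by step:

import pandas as pd

full = pd.read_csv('toy/full/metrics.csv')
combined = pd.concat([pd.read_csv('toy/halfway/metrics.csv'),
                      pd.read_csv('toy/resumed/metrics.csv')])

# align both runs on the global step; with the scheduler state restored correctly,
# the learning-rate curves should agree closely at every step
merged = full.merge(combined, on='step', suffixes=('_full', '_resumed'))
diff = (merged['learning rate_full'] - merged['learning rate_resumed']).abs().max()
print(f'max learning-rate difference between full and resumed runs: {diff:.3g}')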