Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Spurious validation step when restarting from a checkpoint with `max_steps` set in the Trainer #18645

Open arnaudstiegler opened 9 months ago

arnaudstiegler commented 9 months ago

Bug description

When restarting from an existing checkpoint with a trainer that has `max_steps` set, the trainer runs a single validation step before actually resuming the training epoch (even if `num_sanity_val_steps` is set to 0). This is an issue because that one validation step is treated as a full validation round: callbacks are applied, which means the TensorBoard logger logs those spurious data points. Note that this only happens when `max_steps` is set, not when `max_epochs` is, so there might be something going on with the management of the global step.

This problem has also been reported in this discussion: https://github.com/Lightning-AI/lightning/discussions/18110
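
To make the extra round easier to spot, here is a minimal diagnostic sketch (the `ValidationTracer` callback is hypothetical, not part of the repro below): it prints the trainer's global step whenever a validation round starts and ends, so the round that fires right after resuming stands out in the console output.

```python
import pytorch_lightning as pl


class ValidationTracer(pl.Callback):
    """Hypothetical helper: trace every validation round with its global step."""

    def on_validation_start(self, trainer, pl_module):
        # Fires for every validation round, including the extra one on resume.
        print(f"validation start: global_step={trainer.global_step}, "
              f"epoch={trainer.current_epoch}, sanity={trainer.sanity_checking}")

    def on_validation_end(self, trainer, pl_module):
        print(f"validation end:   global_step={trainer.global_step}")
```

Adding it to the Trainer's `callbacks` for both the initial and the resumed run should show one validation round logged immediately after `trainer.fit(..., ckpt_path=...)` resumes, even with `num_sanity_val_steps=0`.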

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import os

import pytorch_lightning as pl
from torch import nn, optim, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("val_loss", loss)
        return loss

    def on_validation_end(self) -> None:
        print('Ending Validation')
        return super().on_validation_end()

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

def setup_and_train(ckpt_path):
    # init the autoencoder
    autoencoder = LitAutoEncoder(encoder, decoder)

    train_dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
    test_dataset = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())
    train_loader = utils.data.DataLoader(train_dataset)
    test_loader = utils.data.DataLoader(test_dataset)

    ckpt_dir = '/content/ckpt_dir'

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=ckpt_dir,
        save_weights_only=False,
        save_top_k=0,
        save_last=True,
    )
    # Initial training (pass ckpt_path=None for a fresh run, or a checkpoint path to resume)
    trainer = pl.Trainer(
        limit_train_batches=2000,
        max_steps=100000,
        num_sanity_val_steps=0,
        val_check_interval=1.0,
        callbacks=checkpoint_callback,
    )
    trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=test_loader, ckpt_path=ckpt_path)

I used this script in a notebook, changing the trainer configuration and stopping the run manually. You can see that when the run is restarted from a checkpoint with `max_steps` set, one validation round runs before training resumes.
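
For reference, the two-phase run looks roughly like this (assuming the `/content/ckpt_dir` directory from the script above, where `save_last=True` should have written a `last.ckpt`):

```python
# First run: train from scratch and interrupt manually after a few validation rounds.
setup_and_train(ckpt_path=None)

# Second run: resume from the last checkpoint written by ModelCheckpoint(save_last=True).
# With max_steps set on the Trainer, one spurious validation round is logged before
# training actually resumes; with max_epochs instead, it does not happen (per the
# description above).
setup_and_train(ckpt_path=os.path.join('/content/ckpt_dir', 'last.ckpt'))
```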


### Error messages and logs

Error messages and logs here please



### Environment

    - toml:              0.10.2
    - tomli:             2.0.1
    - toolz:             0.12.0
    - torch:             2.0.1+cu118
    - torchaudio:        2.0.2+cu118
    - torchdata:         0.6.1
    - torchmetrics:      1.2.0
    - torchsummary:      1.5.1
    - torchtext:         0.15.2
    - torchvision:       0.15.2+cu118
    - tornado:           6.3.2
    - tqdm:              4.66.1
    - traitlets:         5.7.1
    - traittypes:        0.2.1
    - transformers:      4.33.2
    - triton:            2.0.0
    - tweepy:            4.13.0
    - typer:             0.9.0
    - types-setuptools:  68.2.0.0
    - typing-extensions: 4.5.0
    - tzlocal:           5.0.1
    - uc-micro-py:       1.0.2
    - uritemplate:       4.1.1
    - urllib3:           2.0.4
    - vega-datasets:     0.9.0
    - wadllib:           1.3.6
    - wasabi:            1.1.2
    - wcwidth:           0.2.6
    - webcolors:         1.13
    - webencodings:      0.5.1
    - websocket-client:  1.6.2
    - werkzeug:          2.3.7
    - wheel:             0.41.2
    - widgetsnbextension: 3.6.5
    - wordcloud:         1.9.2
    - wrapt:             1.15.0
    - xarray:            2023.7.0
    - xarray-einstats:   0.6.0
    - xgboost:           1.7.6
    - xlrd:              2.0.1
    - xxhash:            3.3.0
    - xyzservices:       2023.7.0
    - yarl:              1.9.2
    - yellowbrick:       1.5
    - yfinance:          0.2.28
    - zict:              3.0.0
    - zipp:              3.16.2
* System:
    - OS:                Linux
    - architecture:
        - 64bit
        - ELF
    - processor:         x86_64
    - python:            3.10.12
    - release:           5.15.120+
    - version:           #1 SMP Wed Aug 30 11:19:59 UTC 2023

### More info

_No response_

cc @carmocca @justusschock
roudimit commented 4 months ago

I'm experiencing the same issue.