Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Validation runs only for one iteration when restarting from checkpoint mid-epoch, wrongly reporting validation loss #19549

Open pimdh opened 6 months ago

pimdh commented 6 months ago

Bug description

When resuming from a mid-epoch checkpoint (which I have to use because my dataset is large), the training loop runs the validation loop for only one iteration, so the logged validation loss is wrong.

It appears that the batch_progress of lightning.pytorch.loops._EvaluationLoop is wrongly restored from the checkpoint as if the validation loop had already finished, and is not properly reset after the checkpoint is loaded.
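The suspected mechanism can be illustrated with a deliberately simplified, hypothetical model. `ToyBatchProgress` and `ToyEvalLoop` are invented for illustration and are not Lightning's `_BatchProgress` or `_EvaluationLoop`. The assumptions carried over from the report are only that the progress counter is saved in and restored from the checkpoint, and that it is reset when a fresh (non-restarting) run begins rather than when a run ends.

```python
from dataclasses import dataclass


@dataclass
class ToyBatchProgress:
    # Hypothetical stand-in for a batch-progress counter; not Lightning's class.
    completed: int = 0


class ToyEvalLoop:
    """Hypothetical evaluation loop that resets progress only at the START of a fresh run."""

    def __init__(self, num_batches: int) -> None:
        self.num_batches = num_batches
        self.batch_progress = ToyBatchProgress()
        self.restarting = False

    def run(self) -> list[int]:
        if not self.restarting:
            # Fresh run: start counting from zero.
            self.batch_progress.completed = 0
        # When restarting, the restored counter is trusted and those batches are skipped.
        self.restarting = False
        processed = []
        while self.batch_progress.completed < self.num_batches:
            processed.append(self.batch_progress.completed)
            self.batch_progress.completed += 1
        # Nothing resets the counter here, so it stays at num_batches between runs.
        return processed

    def state_dict(self) -> dict:
        return {"completed": self.batch_progress.completed}

    def load_state_dict(self, state: dict) -> None:
        self.batch_progress.completed = state["completed"]
        self.restarting = True


# First training run: validation finishes normally.
loop = ToyEvalLoop(num_batches=3)
first = loop.run()        # [0, 1, 2]

# A mid-epoch training checkpoint is written after validation has finished.
ckpt = loop.state_dict()  # {"completed": 3}, indistinguishable from a finished run

# Resuming: the restored counter makes the next validation run skip its batches.
resumed = ToyEvalLoop(num_batches=3)
resumed.load_state_dict(ckpt)
second = resumed.run()    # [] in this toy
```

The toy skips every batch, whereas the log in this issue shows Lightning running one; it only illustrates why a stale counter shortens the resumed validation run, not the exact number of batches Lightning skips.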

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import os

import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        print("Training", batch.shape, loss.item(), batch_idx)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        print("Validation", batch.shape, loss.item())
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_train_epoch_end(self) -> None:
        return super().on_train_epoch_end()

def run(ckpt_path):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=4,
        limit_val_batches=3,
        val_check_interval=2,
        max_epochs=1,
        enable_model_summary=False,
        enable_progress_bar=False,
        num_sanity_val_steps=0,
        logger=False,
        callbacks=[
            ModelCheckpoint(
                save_last=False,
                save_top_k=10,
                monitor="valid_loss",
                every_n_train_steps=1,
                dirpath="./checkpoints",
                enable_version_counter=False,
            )
        ],
    )
    trainer.fit(
        model,
        train_dataloaders=train_data,
        val_dataloaders=val_data,
        ckpt_path=ckpt_path,
    )

run(ckpt_path=None)
run(ckpt_path="checkpoints/epoch=0-step=3.ckpt")

Error messages and logs

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/.../lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory /local/mnt/workspace/pim/projects/equi-scaling/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/.../lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/.../lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
Training torch.Size([2, 32]) -0.8351932168006897 0
/.../lightning/pytorch/callbacks/model_checkpoint.py:382: `ModelCheckpoint(monitor='valid_loss')` could not find the monitored key in the returned metrics: ['train_loss', 'epoch', 'step']. HINT: Did you call `log('valid_loss', value)` in the `LightningModule`?
Training torch.Size([2, 32]) -6.302826881408691 1
Validation torch.Size([2, 32]) -3.5122809410095215
Validation torch.Size([2, 32]) 2.169618844985962
Validation torch.Size([2, 32]) -6.107339859008789
Training torch.Size([2, 32]) -7.338858604431152 2
Training torch.Size([2, 32]) 9.845891952514648 3
Validation torch.Size([2, 32]) -4.606845378875732
Validation torch.Size([2, 32]) 8.005867004394531
Validation torch.Size([2, 32]) -4.361298561096191
`Trainer.fit` stopped: `max_epochs=1` reached.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at checkpoints/epoch=0-step=3.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Restored all states from the checkpoint at checkpoints/epoch=0-step=3.ckpt
/.../lightning/pytorch/loops/training_epoch_loop.py:156: You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable results if further training is done. Consider using an end-of-epoch checkpoint
Validation torch.Size([2, 32]) 8.310140609741211   # <--- only one iteration
Training torch.Size([2, 32]) -4.653291702270508 2
Training torch.Size([2, 32]) -8.187255859375 3
Validation torch.Size([2, 32]) 9.539518356323242
Validation torch.Size([2, 32]) -8.617881774902344
Validation torch.Size([2, 32]) 0.5192334651947021
`Trainer.fit` stopped: `max_epochs=1` reached.

Environment

Current environment

```
* CUDA:
    - GPU:
        - NVIDIA GeForce RTX 2080 Ti
    - available: True
    - version: 11.7
* Lightning:
    - lightning: 2.2.0.post0
    - lightning-utilities: 0.10.1
    - pytorch-lightning: 2.2.0.post0
    - torch: 2.0.1
    - torch-ema: 0.3
    - torch-geometric: 2.5.0
    - torch-scatter: 2.1.2+pt20cu117
    - torchmetrics: 1.3.1
    - torchvision: 0.15.2
* System:
    - OS: Linux
    - architecture:
        - 64bit
    - processor: x86_64
    - python: 3.10.13
    - release: 5.4.0-152-generic
```

More info

A fix/workaround for this issue is to add self.batch_progress.reset_on_run() at the end of _EvaluationLoop.run.
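The effect of the proposed workaround can be sketched with the same kind of hypothetical toy model. `ToyBatchProgress`, `ToyEvalLoop`, the `reset_at_end` flag and `batches_after_resume` are invented for illustration; only the idea of calling `reset_on_run()` once the run has finished comes from the suggestion above, and the real `_EvaluationLoop` has many more restart paths than this model.

```python
from dataclasses import dataclass


@dataclass
class ToyBatchProgress:
    # Hypothetical stand-in for Lightning's batch-progress tracker.
    completed: int = 0

    def reset_on_run(self) -> None:
        # Mirrors the intent of the proposed fix: clear the per-run counter.
        self.completed = 0


class ToyEvalLoop:
    """Hypothetical evaluation loop; `reset_at_end` toggles the proposed workaround."""

    def __init__(self, num_batches: int, reset_at_end: bool) -> None:
        self.num_batches = num_batches
        self.reset_at_end = reset_at_end
        self.batch_progress = ToyBatchProgress()
        self.restarting = False

    def run(self) -> int:
        if not self.restarting:
            self.batch_progress.reset_on_run()
        self.restarting = False
        count = 0
        while self.batch_progress.completed < self.num_batches:
            self.batch_progress.completed += 1
            count += 1
        if self.reset_at_end:
            # Workaround: leave a clean counter behind once the run is complete,
            # so a checkpoint written between runs does not look "finished".
            self.batch_progress.reset_on_run()
        return count

    def state_dict(self) -> dict:
        return {"completed": self.batch_progress.completed}

    def load_state_dict(self, state: dict) -> None:
        self.batch_progress.completed = state["completed"]
        self.restarting = True


def batches_after_resume(reset_at_end: bool, num_batches: int = 3) -> int:
    """Run validation once, checkpoint, resume, and count batches in the resumed run."""
    loop = ToyEvalLoop(num_batches, reset_at_end)
    loop.run()
    resumed = ToyEvalLoop(num_batches, reset_at_end)
    resumed.load_state_dict(loop.state_dict())
    return resumed.run()


print(batches_after_resume(reset_at_end=False))  # 0: stale counter, validation skipped
print(batches_after_resume(reset_at_end=True))   # 3: full validation after resume
```

Resetting after the run, rather than only before it, is what keeps a checkpoint written between validation runs from looking like a completed validation. Whether this is safe for every Lightning restart path is exactly what running the full test suite on the PR would check.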

cc @carmocca @justusschock

awaelchli commented 6 months ago

@pimdh Thank you for already investigating this.

Since the training loop is quite complex, I can't say for sure that this is the right solution, but it sounds reasonable. Would you be interested in sending a PR with this change? We can then run the full test suite on your PR and see whether there are any edge cases. If it works, I can help add a test case.

pimdh commented 6 months ago

Hi @awaelchli, I've filed the PR at #19583. While this suffices for my use case, unfortunately I won't have time to add unit tests to validate it. Thanks!

bkntr commented 3 months ago

The same bug also happens when resuming training that was run on an IterableDataset.