Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Step when validation happens drifts for `val_check_interval` when gradient accumulation is turned on #17207

Open hrukalive opened 1 year ago

hrukalive commented 1 year ago

Bug description

First of all, my task relies on step counts rather than epochs, so I run validation checks by step and save checkpoints after each one. However, once I turned on gradient accumulation with a number of batches per epoch that is not divisible by the accumulation factor, the step at which validation (and therefore checkpointing) happens started to drift.

In the example below, I override the `_save_checkpoint` function to print the checkpoint file name, and the step in that name drifts. My setting is `val_check_interval = accumulate_grad_batches * 5`, so that validation runs every 5 effective optimizer steps, with `accumulate_grad_batches = 3` and 67 batches per epoch, which leaves one batch over (67 = 3 * 22 + 1).

How to reproduce the bug

import numpy as np
import pathlib

import time
import torch
import torch.nn as nn
import torch.optim

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

class Quadratic(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(0.0))
        self.b = nn.Parameter(torch.tensor(0.0))
        self.c = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        time.sleep(0.02)
        return self.a * x * x + self.b * x + self.c

    def _common_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        return loss 

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

class CustomModelCheckpoint(ModelCheckpoint):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        monitor_candidates = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in self._monitor_candidates(trainer).items()}
        print("\n", "Save checkpoint, global_step: ", trainer.global_step, pathlib.Path(filepath).stem, "monitor_candidates: " + str(monitor_candidates), "\n", flush=True)

    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        # print("Remove checkpoint: ", filepath, flush=True)
        pass

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-a', type=float, default=2.0)
    parser.add_argument('-b', type=float, default=3.0)
    parser.add_argument('-c', type=float, default=4.0)
    parser.add_argument('--epoch', type=int, default=500)
    args = parser.parse_args()

    x = torch.from_numpy(np.random.uniform(-10, 10, 2144)).float()  # 2144 samples / batch_size 32 = 67 batches per epoch
    y = args.a * x * x + args.b * x + args.c
    x2 = torch.from_numpy(np.random.uniform(-10, 10, 100)).float()
    y2 = args.a * x2 * x2 + args.b * x2 + args.c

    dataset = torch.utils.data.TensorDataset(x, y)
    val_dataset = torch.utils.data.TensorDataset(x2, y2)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

    model = Quadratic()

    ####
    accumulate_grad_batches = 3
    val_check_interval = 5 * accumulate_grad_batches  # 15 batches, intended as every 5 optimizer steps
    ####

    trainer = pl.Trainer(max_epochs=args.epoch, accelerator='cpu', callbacks=[CustomModelCheckpoint(
                    dirpath='.',
                    filename='steps_{step}',
                    monitor='step',
                    mode='max',
                    save_last=False,
                    save_top_k=5
                )],
            val_check_interval=val_check_interval,
            check_val_every_n_epoch=None,
            num_sanity_val_steps=0,
            accumulate_grad_batches=accumulate_grad_batches)
    trainer.fit(model, dataloader, val_dataloader)

    # Print the results
    print("a = ", model.a.item())
    print("b = ", model.b.item())
    print("c = ", model.c.item())

Error messages and logs

Save checkpoint, global_step:  5 steps_step=5 monitor_candidates: {'epoch': 0, 'step': 5}
Save checkpoint, global_step:  10 steps_step=10 monitor_candidates: {'epoch': 0, 'step': 10}
Save checkpoint, global_step:  15 steps_step=15 monitor_candidates: {'epoch': 0, 'step': 15}
Save checkpoint, global_step:  20 steps_step=20 monitor_candidates: {'epoch': 0, 'step': 20}
Save checkpoint, global_step:  25 steps_step=25 monitor_candidates: {'epoch': 1, 'step': 25}
Save checkpoint, global_step:  30 steps_step=30 monitor_candidates: {'epoch': 1, 'step': 30}
Save checkpoint, global_step:  35 steps_step=35 monitor_candidates: {'epoch': 1, 'step': 35}
Save checkpoint, global_step:  40 steps_step=40 monitor_candidates: {'epoch': 1, 'step': 40}

Save checkpoint, global_step:  46 steps_step=46 monitor_candidates: {'epoch': 2, 'step': 46}  <-- drift
Save checkpoint, global_step:  51 steps_step=51 monitor_candidates: {'epoch': 2, 'step': 51}
Save checkpoint, global_step:  56 steps_step=56 monitor_candidates: {'epoch': 2, 'step': 56}
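The logged steps can be reproduced without Lightning. The sketch below is a plain-Python model of the schedule, not Lightning's code, and it assumes two things the log suggests: (a) the optimizer steps after every 3 batches and also on the last batch of each epoch, even when that accumulation window is incomplete, and (b) with `check_val_every_n_epoch=None`, validation fires after every 15th training batch counted across epochs. The function name is hypothetical.

```python
def validation_steps(batches_per_epoch, accumulate, val_every_batches, num_validations):
    """Return global_step at each validation when validation is counted in batches."""
    global_step = 0
    total_batches = 0
    steps = []
    while len(steps) < num_validations:
        for batch_idx in range(batches_per_epoch):
            total_batches += 1
            # Assumed: the optimizer steps when an accumulation window is full,
            # or on the last batch of the epoch (possibly an incomplete window).
            if (batch_idx + 1) % accumulate == 0 or batch_idx + 1 == batches_per_epoch:
                global_step += 1
            # Validation is triggered by the batch counter, not by global_step.
            if total_batches % val_every_batches == 0:
                steps.append(global_step)
                if len(steps) == num_validations:
                    break
    return steps


print(validation_steps(67, 3, 15, 11))
# [5, 10, 15, 20, 25, 30, 35, 40, 46, 51, 56]
```

With 67 batches per epoch each epoch yields 23 optimizer steps instead of 22.33, so a fixed 15-batch interval stops lining up with multiples of 5 steps; with 66 batches per epoch the same model validates exactly at steps 5, 10, 15, and so on.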

Environment

Current environment
* CUDA:
  - GPU:
    - NVIDIA RTX A5000
    - NVIDIA RTX A5000
    - NVIDIA RTX A5000
    - NVIDIA RTX A5000
  - available: True
  - version: 11.7
* Lightning:
  - lightning: 2.0.0
  - lightning-cloud: 0.5.32
  - lightning-lite: 1.8.6
  - lightning-utilities: 0.8.0
  - pytorch-lightning: 2.0.0
  - torch: 1.13.1
  - torchaudio: 0.13.1
  - torchcrepe: 0.0.17
  - torchmetrics: 0.11.4
  - torchvision: 0.14.1
* Packages: absl-py: 1.3.0 - aiobotocore: 2.4.2 - aiohttp: 3.8.4 - aioitertools: 0.11.0 - aiosignal: 1.3.1 - altgraph: 0.17.3 - anyio: 3.6.2 - appdirs: 1.4.4 - arrow: 1.2.3 - async-timeout: 4.0.2 - attrs: 22.2.0 - audioread: 3.0.0 - backcall: 0.2.0 - beautifulsoup4: 4.12.0 - blessed: 1.20.0 - blinker: 1.4 - botocore: 1.27.59 - brotlipy: 0.7.0 - cachetools: 5.3.0 - certifi: 2022.12.7 - cffi: 1.15.1 - charset-normalizer: 2.0.4 - click: 8.1.3 - contourpy: 1.0.7 - croniter: 1.3.8 - cryptography: 39.0.1 - cycler: 0.11.0 - dateutils: 0.6.12 - decorator: 5.1.1 - deepdiff: 6.3.0 - distance: 0.1.3 - dnspython: 2.3.0 - einops: 0.6.0 - email-validator: 1.3.1 - et-xmlfile: 1.0.1 - fastapi: 0.88.0 - fire: 0.5.0 - flit-core: 3.8.0 - fonttools: 4.39.2 - frozenlist: 1.3.3 - fsspec: 2023.3.0 - future: 0.18.2 - g2p-en: 2.1.0 - g2pm: 0.1.2.5 - google-auth: 2.16.3 - google-auth-oauthlib: 0.4.6 - grpcio: 1.51.3 - h11: 0.14.0 - h5py: 3.7.0 - httpcore: 0.16.3 - httptools: 0.5.0 - httpx: 0.23.3 - idna: 3.4 - imageio: 2.23.0 - importlib-metadata: 6.1.0 - inflect: 6.0.2 - inquirer: 3.1.3 - itsdangerous: 2.1.2 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.2.0 - kiwisolver: 1.4.4 - librosa: 0.9.1 - lightning: 2.0.0 - lightning-cloud: 0.5.32 - lightning-lite: 1.8.6 - lightning-utilities: 0.8.0 - llvmlite: 0.39.1 - markdown: 3.4.3 - markdown-it-py: 2.2.0 - markupsafe: 2.1.2 - matplotlib: 3.6.2 - mdurl: 0.1.2 - mkl-fft: 1.3.1 - mkl-random: 1.2.2 - mkl-service: 2.4.0 - multidict: 6.0.4 - networkx: 3.0 - nltk: 3.8.1 - numba: 0.56.4 - numpy: 1.23.5 - oauthlib: 3.2.2 - ordered-set: 4.1.0 - orjson: 3.8.8 - packaging: 23.0 - pillow: 9.4.0 - pip: 23.0.1 - platformdirs: 3.1.1 - pooch: 1.7.0 - praat-parselmouth: 0.4.3 - protobuf: 3.13.0 - psutil: 5.9.4 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pycparser: 2.21 - pycwt: 0.3.0a22 - pydantic: 1.10.7 - pygments: 2.14.0 - pyjwt: 2.6.0 - pyloudnorm: 0.1.0 - pyopenssl: 23.0.0 - pyparsing: 3.0.9 - pypinyin: 0.39.0 - pysocks: 1.7.1 - python-dateutil: 2.8.2 - python-dotenv: 1.0.0 - python-editor: 1.0.4 - python-levenshtein: 0.12.2 - python-multipart: 0.0.6 - pytorch-lightning: 2.0.0 - pytz: 2022.7.1 - pywavelets: 1.4.1 - pyyaml: 6.0 - readchar: 4.0.5 - regex: 2023.3.23 - requests: 2.28.1 - requests-oauthlib: 1.3.1 - resampy: 0.4.2 - resemblyzer: 0.1.1.dev0 - rfc3986: 1.5.0 - rich: 13.3.2 - rsa: 4.9 - s3fs: 2023.3.0 - scikit-image: 0.19.3 - scikit-learn: 1.2.2 - scipy: 1.9.3 - setuptools: 65.6.3 - six: 1.16.0 - snakeviz: 2.1.1 - sniffio: 1.3.0 - soundfile: 0.12.1 - soupsieve: 2.4 - starlette: 0.22.0 - starsessions: 1.3.0 - tensorboard: 2.11.0 - tensorboard-data-server: 0.6.1 - tensorboard-plugin-wit: 1.8.1 - tensorboardx: 2.6 - termcolor: 2.2.0 - threadpoolctl: 3.1.0 - tifffile: 2023.3.21 - torch: 1.13.1 - torchaudio: 0.13.1 - torchcrepe: 0.0.17 - torchmetrics: 0.11.4 - torchvision: 0.14.1 - tornado: 6.2 - tqdm: 4.65.0 - traitlets: 5.9.0 - typing: 3.7.4.3 - typing-extensions: 4.4.0 - ujson: 5.7.0 - urllib3: 1.26.14 - uvicorn: 0.21.1 - uvloop: 0.17.0 - watchfiles: 0.18.1 - wcwidth: 0.2.6 - webrtcvad: 2.0.10 - websocket-client: 1.5.1 - websockets: 10.4 - werkzeug: 2.2.3 - wheel: 0.38.4 - wrapt: 1.15.0 - yarl: 1.8.2 - zipp: 3.15.0
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.9.16
  - version: #153-Ubuntu SMP Thu Nov 24 15:56:58 UTC 2022

More info

Other than this phenomenon, I have two more questions:

  1. Why is val_check_interval tied to the number of batches rather than to global_step?
  2. Why is validation re-run after resuming from a checkpoint that was saved right after a validation step? This also produces a duplicate checkpoint, which is very frustrating.

cc @carmocca @justusschock

hrukalive commented 1 year ago

I think it is actually the step at which validation happens that drifts; the checkpoint saving is just a side effect.

bkiat1123 commented 1 year ago

Validation check tracks training batches instead of training steps. According to the documentation,

An int value can only be higher than the number of training batches when check_val_every_n_epoch=None, which validates after every N training batches across epochs or during iteration-based training.

However, the number of training batches is not always equal to the number of training steps (global steps).

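The training step is (total_batch_idx // accumulate_grad_batches) + (accumulates_on_final_batch * epoch_trained). The accumulates_on_final_batch term is where the drift happens: an epoch whose batch count is not divisible by accumulate_grad_batches still ends with an optimizer step on its incomplete accumulation window.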
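A closed-form version of the step count, as a sketch. It assumes each epoch ends with one optimizer step even when its final accumulation window is incomplete, which is what the logged steps in the report imply (67 batches give 23 steps per epoch, not 22). The function name and arguments are illustrative, not Lightning API.

```python
import math


def global_step_after(epoch, batches_done_in_epoch, batches_per_epoch, accumulate):
    """Optimizer steps taken after `batches_done_in_epoch` batches of 0-indexed `epoch`."""
    # Every completed epoch contributes ceil(batches / accumulate) steps,
    # because the final partial window also triggers a step (assumed).
    steps_per_epoch = math.ceil(batches_per_epoch / accumulate)
    if batches_done_in_epoch == batches_per_epoch:
        current = steps_per_epoch
    else:
        # Only full accumulation windows have stepped so far in this epoch.
        current = batches_done_in_epoch // accumulate
    return epoch * steps_per_epoch + current


# The 9th validation comes after total batch 135, i.e. epoch 2, batch 1:
print(global_step_after(2, 1, 67, 3))  # 46
# A pure batch count would predict 135 // 3 = 45, i.e. 9 * 5 steps.
```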

I think it would make sense to validate after N training steps instead of N training batches. Other components, such as loggers and ModelCheckpoint, also use global steps to track training progress.

I propose changing

https://github.com/Lightning-AI/lightning/blob/83f683243dde71898e2110b12fd5c78ebac5418b/src/lightning/pytorch/loops/training_epoch_loop.py#L392-L398

to

elif self.trainer.val_check_batch != float("inf"):
    # if `check_val_every_n_epoch` is `None`, run a validation loop every n training steps
    # else condition it based on the batch_idx of the current epoch
    next_iteration = self.global_step if self.trainer.check_val_every_n_epoch is None else self.batch_idx + 1
    is_val_check_batch = next_iteration % self.trainer.val_check_batch == 0

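The intent of a step-based trigger can be checked with a small plain-Python simulation. It is not a patch to training_epoch_loop.py, and it adds one assumption: because global_step only changes when the optimizer steps, `global_step % N == 0` stays true for several consecutive batches, so the sketch tests the modulo only on batches that just produced an optimizer step. Here N is counted in optimizer steps (5), not batches, and the function name is hypothetical. It also assumes, as in the report, that the last batch of each epoch triggers a step.

```python
def step_based_validation_steps(batches_per_epoch, accumulate, val_every_steps, num_validations):
    """Return global_step at each validation when the trigger is step-based."""
    global_step = 0
    steps = []
    while len(steps) < num_validations:
        for batch_idx in range(batches_per_epoch):
            # Assumed: step on a full window or on the epoch's last batch.
            stepped = (batch_idx + 1) % accumulate == 0 or batch_idx + 1 == batches_per_epoch
            if stepped:
                global_step += 1
            # Check only right after an optimizer step; otherwise the same
            # global_step would trigger validation on several batches.
            if stepped and global_step % val_every_steps == 0:
                steps.append(global_step)
                if len(steps) == num_validations:
                    break
    return steps


print(step_based_validation_steps(67, 3, 5, 11))
# [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55]
```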
anjali-chadha commented 6 months ago

Is there a plan to add step-based validation checks in Lightning?

Until Lightning adds official support for this, any recommendations on how we can override the default Lightning behavior and use number of steps instead of batches to trigger the validation?

hrukalive commented 5 months ago

Right now, my workaround is to discard the leftover batch(es) so that the number of batches per epoch is a multiple of accumulate_grad_batches.
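The number of samples to keep for that workaround can be computed up front. A minimal sketch, assuming a map-style dataset, a fixed batch_size and drop_last=False; the helper name is hypothetical. The result could then be passed as `torch.utils.data.Subset(dataset, range(n))` before building the DataLoader.

```python
import math


def trimmed_num_samples(num_samples, batch_size, accumulate):
    """Largest sample count whose batch count is a multiple of `accumulate`.

    Returns 0 if there are fewer than `accumulate` batches in total.
    """
    num_batches = math.ceil(num_samples / batch_size)  # drop_last=False
    kept_batches = (num_batches // accumulate) * accumulate
    if kept_batches == num_batches:
        # Already aligned: keep everything, including a partial last batch.
        return num_samples
    return kept_batches * batch_size


print(trimmed_num_samples(2144, 32, 3))  # 2112 -> 66 batches, 22 optimizer steps per epoch
```

Because the kept indices are fixed, the same trailing samples are left out every epoch; shuffling still applies within the kept subset.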