Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai

Restoring a checkpoint (from model checkpoint) via `trainer.test(model_module, ckpt_path='best')` doesn't restore associated `current_epoch` within the trainer #19558

Open maciejzj opened 7 months ago

maciejzj commented 7 months ago

Bug description

When a checkpoint (presumably the 'best' one saved by the ModelCheckpoint callback) is restored for testing via `trainer.test(model_module, ckpt_path='best')`, the `current_epoch` member of the trainer is not restored to the value associated with the checkpoint, even though the checkpoint file stores the correct epoch value.

I am not sure whether this is intended (it seems incorrect to me). I asked about this problem on the Discord forum earlier but got no response, so I decided to open an issue here.
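For reference, the epoch a checkpoint was saved at can be read directly from the file itself; Lightning checkpoints are plain torch files that store it under the "epoch" key (a minimal sketch, with a placeholder path):

import torch

# Placeholder path; substitute the actual checkpoint produced by ModelCheckpoint.
ckpt = torch.load("lightning_logs/version_0/checkpoints/best.ckpt", map_location="cpu")
print(ckpt["epoch"])  # the epoch recorded when the checkpoint was saved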

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import os

import torch
from lightning.pytorch import LightningModule, Trainer, callbacks, seed_everything
from torch.utils.data import DataLoader, Dataset

class RandomDataset(Dataset):
    def __init__(self, size, length, offset=0):
        self.len = length
        self.data = torch.randn(length, size) + offset

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        print(f'CURRENT EPOCH: {self.current_epoch}')
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run():
    seed_everything(10)

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64, offset=16), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        max_epochs=2,
        enable_model_summary=False,
        # Keep a handle (mc) to the checkpoint callback so best_model_path can be read later.
        callbacks=[(mc := callbacks.ModelCheckpoint(monitor="valid_loss", verbose=True))],
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    # Restore the best checkpoint for testing; test_step prints the trainer's current_epoch.
    trainer.test(model, dataloaders=test_data, ckpt_path="best")
    # Epoch actually stored in the best checkpoint file, for comparison.
    print(f'CKPT EPOCH: {torch.load(mc.best_model_path, map_location="cpu")["epoch"]}')

if __name__ == "__main__":
    run()

Error messages and logs

Epoch 0: 100%|███████████████████████████████████████████| 32/32 [00:00<00:00, 277.12it/s, v_num=7]
Epoch 0, global step 32: 'valid_loss' reached 183.69585 (best 183.69585), saving model to '/Users/maciej/Desktop/lightning_logs/version_7/checkpoints/epoch=0-step=32.ckpt' as top 1
Epoch 1: 100%|███████████████████████████████████████████| 32/32 [00:00<00:00, 432.74it/s, v_num=7]
Epoch 1, global step 64: 'valid_loss' was not in top 1                                             
`Trainer.fit` stopped: `max_epochs=2` reached.
Epoch 1: 100%|███████████████████████████████████████████| 32/32 [00:00<00:00, 410.05it/s, v_num=7]
Restoring states from the checkpoint path at /Users/maciej/Desktop/lightning_logs/version_7/checkpoints/epoch=0-step=32.ckpt
Loaded model weights from the checkpoint at /Users/maciej/Desktop/lightning_logs/version_7/checkpoints/epoch=0-step=32.ckpt
/Users/maciej/.local/share/pyenv/versions/3.11.7/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
Testing DataLoader 0:   0%|                                                  | 0/1 [00:00<?, ?it/s]
CURRENT EPOCH: 2
Testing DataLoader 0: 100%|█████████████████████████████████████████| 1/1 [00:00<00:00, 380.95it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss         │    -40.08610534667969     │
└───────────────────────────┴───────────────────────────┘
CKPT EPOCH: 0

Notice that the current epoch printed by `test_step` after loading the "best" checkpoint (2) does not match the epoch stored inside the checkpoint file (0). This indicates that loading the 'best' model doesn't restore the epoch number. Should it?
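Put differently, one might expect the following check to pass after `trainer.test(..., ckpt_path="best")`, but on v2.2 it fails because the trainer keeps its end-of-fit counter (a sketch reusing the `mc` and `trainer` names from the repro above):

ckpt_epoch = torch.load(mc.best_model_path, map_location="cpu")["epoch"]  # 0 in this run
assert trainer.current_epoch == ckpt_epoch  # fails: trainer.current_epoch is still 2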

Environment

Current environment

```
* CUDA:
  - GPU: None
  - available: False
  - version: None
* Lightning:
  - lightning: 2.2.0.post0
  - lightning-utilities: 0.10.1
  - pytorch-lightning: 2.2.0.post0
  - torch: 2.2.1
  - torchmetrics: 1.3.1
  - torchvision: 0.17.1
* Packages:
  - aiohttp: 3.9.3 - aiohttp-retry: 2.8.3 - aiosignal: 1.3.1 - amqp: 5.2.0 - annotated-types: 0.6.0 - antlr4-python3-runtime: 4.9.3 - anyio: 4.3.0 - appdirs: 1.4.4 - appnope: 0.1.4 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.3.0 - astroid: 3.1.0 - asttokens: 2.4.1 - async-lru: 2.0.4 - asyncssh: 2.14.2 - atpublic: 4.0 - attrs: 23.2.0 - autopep8: 2.0.4 - babel: 2.14.0 - backcall: 0.2.0 - beautifulsoup4: 4.12.3 - billiard: 4.2.0 - black: 24.2.0 - bleach: 6.1.0 - blinker: 1.7.0 - celery: 5.3.6 - certifi: 2024.2.2 - cffi: 1.16.0 - charset-normalizer: 3.3.2 - click: 8.1.7 - click-didyoumean: 0.3.0 - click-plugins: 1.1.1 - click-repl: 0.3.0 - colorama: 0.4.6 - comm: 0.2.1 - configobj: 5.0.8 - contourpy: 1.2.0 - cryptography: 42.0.5 - cycler: 0.12.1 - dash: 2.15.0 - dash-core-components: 2.0.0 - dash-html-components: 2.0.0 - dash-table: 5.0.0 - debugpy: 1.8.1 - decorator: 5.1.1 - defusedxml: 0.7.1 - dictdiffer: 0.9.0 - dill: 0.3.8 - diskcache: 5.6.3 - distlib: 0.3.8 - distro: 1.9.0 - docopt: 0.6.2 - dpath: 2.1.6 - dulwich: 0.21.7 - dvc: 3.48.0 - dvc-data: 3.13.0 - dvc-http: 2.32.0 - dvc-objects: 5.0.0 - dvc-render: 1.0.1 - dvc-studio-client: 0.20.0 - dvc-task: 0.3.0 - einops: 0.7.0 - entrypoints: 0.4 - executing: 2.0.1 - fastjsonschema: 2.19.1 - filelock: 3.13.1 - flake8: 7.0.0 - flask: 3.0.2 - flatten-dict: 0.4.2 - flufl.lock: 7.1.1 - fonttools: 4.49.0 - fqdn: 1.5.1 - frozenlist: 1.4.1 - fsspec: 2024.2.0 - funcy: 2.0 - gitdb: 4.0.11 - gitpython: 3.1.42 - grandalf: 0.8 - greenlet: 3.0.3 - gto: 1.7.0 - h11: 0.14.0 - h5py: 3.10.0 - httpcore: 1.0.4 - httpx: 0.27.0 - hydra-core: 1.3.2 - idna: 3.6 - imageio: 2.34.0 - importlib-metadata: 7.0.1 - iniconfig: 2.0.0 - ipdb: 0.13.13 - ipykernel: 6.29.3 - ipython: 8.12.3 - isoduration: 20.11.0 - isort: 5.13.2 - iterative-telemetry: 0.0.8 - itsdangerous: 2.1.2 - jedi: 0.19.1 - jinja2: 3.1.3 - joblib: 1.3.2 - json5: 0.9.17 - jsonpointer: 2.4 - jsonschema: 4.21.1 - jsonschema-specifications: 2023.12.1 - jupyter-client: 8.6.0 - jupyter-core: 5.7.1 - jupyter-events: 0.9.0 - jupyter-lsp: 2.2.3 - jupyter-server: 2.12.5 - jupyter-server-mathjax: 0.2.6 - jupyter-server-terminals: 0.5.2 - jupyterlab: 4.1.2 - jupyterlab-pygments: 0.3.0 - jupyterlab-server: 2.25.3 - kiwisolver: 1.4.5 - kombu: 5.3.5 - lazy-loader: 0.3 - lightning: 2.2.0.post0 - lightning-utilities: 0.10.1 - markdown-it-py: 3.0.0 - markupsafe: 2.1.5 - matplotlib: 3.8.3 - matplotlib-inline: 0.1.6 - mccabe: 0.7.0 - mdurl: 0.1.2 - mistune: 3.0.2 - mpmath: 1.3.0 - msgpack: 1.0.7 - multidict: 6.0.5 - mypy: 1.8.0 - mypy-extensions: 1.0.0 - nbclient: 0.9.0 - nbconvert: 7.16.1 - nbdime: 4.0.1 - nbformat: 5.9.2 - neovim: 0.3.1 - nest-asyncio: 1.6.0 - networkx: 3.2.1 - notebook-shim: 0.2.4 - numpy: 1.26.4 - omegaconf: 2.3.0 - orjson: 3.9.15 - overrides: 7.7.0 - packaging: 23.2 - pandas: 2.2.1 - pandocfilters: 1.5.1 - parso: 0.8.3 - pathspec: 0.12.1 - pexpect: 4.9.0 - pickleshare: 0.7.5 - pillow: 10.2.0 - pip: 23.2.1 - pipreqs: 0.5.0 - platformdirs: 3.11.0 - plotly: 5.19.0 - pluggy: 1.4.0 - prometheus-client: 0.20.0 - prompt-toolkit: 3.0.43 - psutil: 5.9.8 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pycodestyle: 2.11.1 - pycparser: 2.21 - pydantic: 2.6.3 - pydantic-core: 2.16.3 - pydocstyle: 6.3.0 - pydot: 2.0.0 - pyflakes: 3.2.0 - pygit2: 1.14.1 - pygments: 2.17.2 - pygtrie: 2.5.0 - pylint: 3.1.0 - pynvim: 0.5.0 - pyparsing: 3.1.1 - pytest: 8.0.2 - pytest-mock: 3.12.0 - python-dateutil: 2.8.2 - python-json-logger: 2.0.7 - pytoolconfig: 1.3.1 - pytorch-lightning: 2.2.0.post0 - pytz: 2024.1 - pyyaml: 6.0.1 - pyzmq: 25.1.2 - referencing: 0.33.0 - rentry: 1.0.1 - requests: 2.31.0 - retrying: 1.3.4 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.7.0 - rope: 1.12.0 - rpds-py: 0.18.0 - ruamel.yaml: 0.18.6 - ruamel.yaml.clib: 0.2.8 - scikit-image: 0.22.0 - scikit-learn: 1.4.1.post1 - scipy: 1.12.0 - scmrepo: 3.1.0 - semver: 3.0.2 - send2trash: 1.8.2 - setuptools: 65.5.0 - shortuuid: 1.0.11 - shtab: 1.7.0 - six: 1.16.0 - smmap: 5.0.1 - sniffio: 1.3.1 - snowballstemmer: 2.2.0 - soupsieve: 2.5 - sqltrie: 0.11.0 - stack-data: 0.6.3 - sympy: 1.12 - tabulate: 0.9.0 - tenacity: 8.2.3 - terminado: 0.18.0 - threadpoolctl: 3.3.0 - tifffile: 2024.2.12 - tinycss2: 1.2.1 - tomlkit: 0.12.4 - torch: 2.2.1 - torchmetrics: 1.3.1 - torchvision: 0.17.1 - tornado: 6.4 - tqdm: 4.66.2 - traitlets: 5.14.1 - typer: 0.9.0 - types-python-dateutil: 2.8.19.20240106 - typing-extensions: 4.10.0 - tzdata: 2024.1 - uri-template: 1.3.0 - urllib3: 2.2.1 - vine: 5.1.0 - virtualenv: 20.25.1 - voluptuous: 0.14.2 - wcwidth: 0.2.13 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.7.0 - werkzeug: 3.0.1 - wheel: 0.42.0 - yarg: 0.1.9 - yarl: 1.9.4 - zc.lockfile: 3.0.post1 - zipp: 3.17.0
* System:
  - OS: Darwin
  - architecture: 64bit
  - processor: arm
  - python: 3.11.7
  - release: 23.2.0
  - version: Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000
```

More info

No response

maorag commented 2 weeks ago

I faced the same issue. Until there is a more elegant solution, I solved it with this workaround:


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def on_load_checkpoint(self, checkpoint) -> None:
        # Stash the epoch/step recorded in the checkpoint, since the trainer
        # does not restore them for `trainer.test(..., ckpt_path=...)`.
        self.ckpt_current_epoch = checkpoint['epoch']
        self.ckpt_global_step = checkpoint['global_step']
        return super().on_load_checkpoint(checkpoint)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        print(f'CURRENT EPOCH: {self.ckpt_current_epoch}')
        self.log("test_loss", loss)