Restoring a checkpoint (from model checkpoint) via `trainer.test(model_module, ckpt_path='best')` doesn't restore associated `current_epoch` within the trainer #19558

Open maciejzj opened 7 months ago

maciejzj commented 7 months ago

Bug description

When a checkpoint (presumably the 'best' one saved by the `ModelCheckpoint` callback) is restored for testing via `trainer.test(model_module, ckpt_path='best')`, the trainer's `current_epoch` is not restored to the epoch associated with that checkpoint, even though the checkpoint file stores the correct epoch value.

I am not sure whether this is intended (it seems incorrect to me). I asked about this problem on the Discord forum before but got no response, so I decided to open an issue here.

What version are you seeing the problem on?

lightning 2.2.0.post0 (see Current environment below)

How to reproduce the bug

import os

import torch
from lightning.pytorch import LightningModule, Trainer, callbacks, seed_everything
from torch.utils.data import DataLoader, Dataset

class RandomDataset(Dataset):
    def __init__(self, size, length, offset=0):
        self.len = length
        self.data = torch.randn(length, size) + offset

    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        print(f'CURRENT EPOCH: {self.current_epoch}')
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run():

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64, offset=16), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
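    # Keep a handle on the checkpoint callback to inspect best_model_path later
    mc = callbacks.ModelCheckpoint(monitor="valid_loss", verbose=True)
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        callbacks=[mc],
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    # Restore the best checkpoint for testing; test_step prints the trainer's current epoch
    trainer.test(model, dataloaders=test_data, ckpt_path="best")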
    print(f'CKPT EPOCH: {torch.load(mc.best_model_path, map_location="cpu")["epoch"]}')

if __name__ == "__main__":
    run()
Error messages and logs

Epoch 0: 100%|███████████████████████████████████████████| 32/32 [00:00<00:00, 277.12it/s, v_num=7]
Epoch 0, global step 32: 'valid_loss' reached 183.69585 (best 183.69585), saving model to '/Users/maciej/Desktop/lightning_logs/version_7/checkpoints/epoch=0-step=32.ckpt' as top 1
Epoch 1: 100%|███████████████████████████████████████████| 32/32 [00:00<00:00, 432.74it/s, v_num=7]
Epoch 1, global step 64: 'valid_loss' was not in top 1                                             
`Trainer.fit` stopped: `max_epochs=2` reached.
Epoch 1: 100%|███████████████████████████████████████████| 32/32 [00:00<00:00, 410.05it/s, v_num=7]
Restoring states from the checkpoint path at /Users/maciej/Desktop/lightning_logs/version_7/checkpoints/epoch=0-step=32.ckpt
Loaded model weights from the checkpoint at /Users/maciej/Desktop/lightning_logs/version_7/checkpoints/epoch=0-step=32.ckpt
/Users/maciej/.local/share/pyenv/versions/3.11.7/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/ The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
Testing DataLoader 0:   0%|                                                  | 0/1 [00:00<?, ?it/s]
Testing DataLoader 0: 100%|█████████████████████████████████████████| 1/1 [00:00<00:00, 380.95it/s]
┃        Test metric        ┃       DataLoader 0        ┃
│         test_loss         │    -40.08610534667969     │

Notice that the current epoch printed by `test_step` after loading the "best" checkpoint does not match the epoch stored inside the checkpoint (printed as `CKPT EPOCH`). This indicates that loading the 'best' model does not restore the epoch number. Should it?


Current environment
* CUDA:
  - GPU: None
  - available: False
  - version: None
* Lightning:
  - lightning: 2.2.0.post0
  - pytorch-lightning: 2.2.0.post0
  - torch: 2.2.1
* System:
  - OS: Darwin
  - python: 3.11.7
* Packages:
  - cycler: 0.12.1 - dash: 2.15.0 - dash-core-components: 2.0.0 - dash-html-components: 2.0.0 - dash-table: 5.0.0 - debugpy: 1.8.1 - decorator: 5.1.1 - defusedxml: 0.7.1 - dictdiffer: 0.9.0 - dill: 0.3.8 - diskcache: 5.6.3 - distlib: 0.3.8 - distro: 1.9.0 - docopt: 0.6.2 - dpath: 2.1.6 - dulwich: 0.21.7 - dvc: 3.48.0 - dvc-data: 3.13.0 - dvc-http: 2.32.0 - dvc-objects: 5.0.0 - dvc-render: 1.0.1 - dvc-studio-client: 0.20.0 - dvc-task: 0.3.0 - einops: 0.7.0 - entrypoints: 0.4 - executing: 2.0.1 - fastjsonschema: 2.19.1 - filelock: 3.13.1 - flake8: 7.0.0 - flask: 3.0.2 - flatten-dict: 0.4.2 - flufl.lock: 7.1.1 - fonttools: 4.49.0 - fqdn: 1.5.1 - frozenlist: 1.4.1 - fsspec: 2024.2.0 - funcy: 2.0 - gitdb: 4.0.11 - gitpython: 3.1.42 - grandalf: 0.8 - greenlet: 3.0.3 - gto: 1.7.0 - h11: 0.14.0 - h5py: 3.10.0 - httpcore: 1.0.4 - httpx: 0.27.0 - hydra-core: 1.3.2 - idna: 3.6 - imageio: 2.34.0 - importlib-metadata: 7.0.1 - iniconfig: 2.0.0 - ipdb: 0.13.13 - ipykernel: 6.29.3 - ipython: 8.12.3 - isoduration: 20.11.0 - isort: 5.13.2 - iterative-telemetry: 0.0.8 - itsdangerous: 2.1.2 - jedi: 0.19.1 - jinja2: 3.1.3 - joblib: 1.3.2 - json5: 0.9.17 - jsonpointer: 2.4 - jsonschema: 4.21.1 - jsonschema-specifications: 2023.12.1 - jupyter-client: 8.6.0 - jupyter-core: 5.7.1 - jupyter-events: 0.9.0 - jupyter-lsp: 2.2.3 - jupyter-server: 2.12.5 - jupyter-server-mathjax: 0.2.6 - jupyter-server-terminals: 0.5.2 - jupyterlab: 4.1.2 - jupyterlab-pygments: 0.3.0 - jupyterlab-server: 2.25.3 - kiwisolver: 1.4.5 - kombu: 5.3.5 - lazy-loader: 0.3 - lightning: 2.2.0.post0 - lightning-utilities: 0.10.1 - markdown-it-py: 3.0.0 - markupsafe: 2.1.5 - matplotlib: 3.8.3 - matplotlib-inline: 0.1.6 - mccabe: 0.7.0 - mdurl: 0.1.2 - mistune: 3.0.2 - mpmath: 1.3.0 - msgpack: 1.0.7 - multidict: 6.0.5 - mypy: 1.8.0 - mypy-extensions: 1.0.0 - nbclient: 0.9.0 - nbconvert: 7.16.1 - nbdime: 4.0.1 - nbformat: 5.9.2 - neovim: 0.3.1 - nest-asyncio: 1.6.0 - networkx: 3.2.1 - notebook-shim: 0.2.4 - 
numpy: 1.26.4 - omegaconf: 2.3.0 - orjson: 3.9.15 - overrides: 7.7.0 - packaging: 23.2 - pandas: 2.2.1 - pandocfilters: 1.5.1 - parso: 0.8.3 - pathspec: 0.12.1 - pexpect: 4.9.0 - pickleshare: 0.7.5 - pillow: 10.2.0 - pip: 23.2.1 - pipreqs: 0.5.0 - platformdirs: 3.11.0 - plotly: 5.19.0 - pluggy: 1.4.0 - prometheus-client: 0.20.0 - prompt-toolkit: 3.0.43 - psutil: 5.9.8 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pycodestyle: 2.11.1 - pycparser: 2.21 - pydantic: 2.6.3 - pydantic-core: 2.16.3 - pydocstyle: 6.3.0 - pydot: 2.0.0 - pyflakes: 3.2.0 - pygit2: 1.14.1 - pygments: 2.17.2 - pygtrie: 2.5.0 - pylint: 3.1.0 - pynvim: 0.5.0 - pyparsing: 3.1.1 - pytest: 8.0.2 - pytest-mock: 3.12.0 - python-dateutil: 2.8.2 - python-json-logger: 2.0.7 - pytoolconfig: 1.3.1 - pytorch-lightning: 2.2.0.post0 - pytz: 2024.1 - pyyaml: 6.0.1 - pyzmq: 25.1.2 - referencing: 0.33.0 - rentry: 1.0.1 - requests: 2.31.0 - retrying: 1.3.4 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.7.0 - rope: 1.12.0 - rpds-py: 0.18.0 - ruamel.yaml: 0.18.6 - ruamel.yaml.clib: 0.2.8 - scikit-image: 0.22.0 - scikit-learn: 1.4.1.post1 - scipy: 1.12.0 - scmrepo: 3.1.0 - semver: 3.0.2 - send2trash: 1.8.2 - setuptools: 65.5.0 - shortuuid: 1.0.11 - shtab: 1.7.0 - six: 1.16.0 - smmap: 5.0.1 - sniffio: 1.3.1 - snowballstemmer: 2.2.0 - soupsieve: 2.5 - sqltrie: 0.11.0 - stack-data: 0.6.3 - sympy: 1.12 - tabulate: 0.9.0 - tenacity: 8.2.3 - terminado: 0.18.0 - threadpoolctl: 3.3.0 - tifffile: 2024.2.12 - tinycss2: 1.2.1 - tomlkit: 0.12.4 - torch: 2.2.1 - torchmetrics: 1.3.1 - torchvision: 0.17.1 - tornado: 6.4 - tqdm: 4.66.2 - traitlets: 5.14.1 - typer: 0.9.0 - types-python-dateutil: - typing-extensions: 4.10.0 - tzdata: 2024.1 - uri-template: 1.3.0 - urllib3: 2.2.1 - vine: 5.1.0 - virtualenv: 20.25.1 - voluptuous: 0.14.2 - wcwidth: 0.2.13 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.7.0 - werkzeug: 3.0.1 - wheel: 0.42.0 - yarg: 0.1.9 - yarl: 1.9.4 - zc.lockfile: 3.0.post1 - zipp: 3.17.0 * 
System:
  - OS: Darwin
  - architecture: 64bit
  - processor: arm
  - python: 3.11.7
  - release: 23.2.0
  - version: Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000

maorag commented 2 weeks ago

I faced the same issue. Until there is a more elegant solution, I solved it with this workaround:

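class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    # forward, training_step, validation_step and configure_optimizers
    # are unchanged from the BoringModel in the reproduction above

    def on_load_checkpoint(self, checkpoint) -> None:
        # Store the epoch recorded in the restored checkpoint, since
        # trainer.current_epoch is not reset to it
        self.ckpt_current_epoch = checkpoint["epoch"]
        return super().on_load_checkpoint(checkpoint)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        print(f'CURRENT EPOCH: {self.ckpt_current_epoch}')
        self.log("test_loss", loss)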
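The workaround above can be packaged as a reusable mixin, so that any `LightningModule` exposes the epoch and global step of the checkpoint it was restored from. This is only a sketch, not part of Lightning's API: `CheckpointEpochMixin`, `ckpt_epoch` and `ckpt_global_step` are hypothetical names, and it assumes the checkpoint dict carries top-level `epoch` and `global_step` keys (the reproduction above already reads `["epoch"]` the same way). A small stub stands in for `LightningModule` so the snippet runs on its own; in a real project you would write `class BoringModel(CheckpointEpochMixin, LightningModule)`.

```python
from typing import Any, Dict, Optional


class CheckpointEpochMixin:
    """Hypothetical mixin that remembers which epoch/step a restored checkpoint came from.

    List it *before* LightningModule in the bases so this hook runs first and
    then defers to the next class in the MRO through super().
    """

    # Class-level defaults, so the attributes exist even if no checkpoint was loaded.
    ckpt_epoch: Optional[int] = None
    ckpt_global_step: Optional[int] = None

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Assumption: the counters are stored at the top level of the checkpoint dict.
        self.ckpt_epoch = checkpoint.get("epoch")
        self.ckpt_global_step = checkpoint.get("global_step")
        super().on_load_checkpoint(checkpoint)


class _StubModule:
    """Stand-in for LightningModule so this sketch runs without Lightning installed."""

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.base_hook_called = True


class DemoModel(CheckpointEpochMixin, _StubModule):
    pass


model = DemoModel()
print(model.ckpt_epoch)  # None: nothing restored yet
# Values mirror the best checkpoint from the log (epoch=0-step=32)
model.on_load_checkpoint({"epoch": 0, "global_step": 32, "state_dict": {}})
print(model.ckpt_epoch, model.ckpt_global_step)  # 0 32
```

Using `.get` instead of indexing keeps the hook from failing on a checkpoint that lacks these keys; the attributes then simply stay `None`.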