trainer.validate() get different result from #20179

Open matrix72c opened 3 months ago

matrix72c commented 3 months ago

Bug description

I'm training a ResNet50 model and using ModelCheckpoint (weights only) to save the best model. However, the validation accuracy reported during fit differs from the one reported by validate. I start training with LightningCLI and set run=False, then call fit and validate manually.

What version are you seeing the problem on?


How to reproduce the bug

Main code

import lightning as L
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchmetrics import Accuracy
from lightning.pytorch.cli import LightningCLI

class ResNet(L.LightningModule):
    def __init__(
        self,
        num_classes: int,
        use_pretrained: bool = True,
        lr: float = 1e-3,
        step_size: int = 10,
        gamma: float = 0.1,
    ):
        super().__init__()
        # Makes lr, step_size and gamma available through self.hparams
        self.save_hyperparameters()
        self.model = torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.DEFAULT if use_pretrained else None
        )
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.acc = Accuracy(task="multiclass", num_classes=num_classes)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma
        )
        return [optimizer], [scheduler]

    def forward(self, x):
        x = self.model(x)
        return x

    def shared_step(self, batch):
        img, label, _ = batch
        logits = self(img)
        loss = F.cross_entropy(logits, label)
        self.acc(logits, label)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.acc, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        _ = self.shared_step(batch)
        self.log("val_acc", self.acc, prog_bar=True, on_epoch=True, on_step=True)

    def test_step(self, batch, batch_idx):
        _ = self.shared_step(batch)
        self.log("test_acc", self.acc, prog_bar=True, on_epoch=True, on_step=True)

if __name__ == "__main__":
    cli = LightningCLI(save_config_callback=None, run=False)
    cli.trainer.fit(cli.model, cli.datamodule)
    cli.trainer.validate(cli.model, cli.datamodule, ckpt_path="checkpoint/resnet_CUB.ckpt")

config file

seed_everything: 42

trainer:
  max_epochs: 1000
  log_every_n_steps: 1
  logger:
    class_path: aim.pytorch_lightning.AimLogger
    init_args:
      run_name: resnet_CUB
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_acc
        dirpath: checkpoints/
        filename: resnet_CUB
        save_top_k: 1
        mode: max
        save_weights_only: True
        enable_version_counter: False
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        monitor: val_acc
        mode: max
        min_delta: 0.001
        patience: 0

model:
  class_path: model.ResNet
  init_args:
    num_classes: 200

data:
  class_path: dataset.CUB
  init_args:
    data_path: ./data
    batch_size: 128

Error messages and logs

In fit, the final acc_epoch is 0.86, while in trainer.validate, it becomes 0.74.


Current environment

* CUDA: - GPU: - NVIDIA GeForce RTX 4090 - available: True - version: 11.8
* Lightning: - lightning: 2.2.5 - lightning-utilities: 0.11.2 - pytorch-lightning: 2.2.2 - torch: 2.1.1 - torchattacks: 3.5.1 - torchaudio: 2.1.1 - torchmetrics: 1.4.0 - torchvision: 0.16.1
* Packages: - absl-py: 2.1.0 - aim: 3.22.0 - aim-ui: 3.22.0 - aimrecords: 0.0.7 - aimrocks: 0.5.2 - aiofiles: 23.2.1 - aiohttp: 3.9.5 - aiosignal: 1.3.1 - alembic: 1.13.0 - annotated-types: 0.6.0 - antlr4-python3-runtime: 4.9.3 - anyio: 3.7.1 - argcomplete: 3.4.0 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.3.0 - asttokens: 2.4.1 - astunparse: 1.6.3 - async-lru: 2.0.4 - async-timeout: 4.0.3 - attrs: 23.1.0 - babel: 2.14.0 - backoff: 2.2.1 - base58: 2.0.1 - beautifulsoup4: 4.12.3 - bitsandbytes: 0.41.0 - black: 24.2.0 - bleach: 6.1.0 - boto3: 1.34.62 - botocore: 1.34.62 - bottleneck: 1.3.5 - brotli: 1.0.9 - cachetools: 5.3.2 - certifi: 2024.7.4 - cffi: 1.16.0 - chardet: 4.0.0 - charset-normalizer: 2.0.4 - click: 8.1.7 - colorama: 0.4.6 - comm: 0.2.1 - contourpy: 1.2.0 - cryptography: 41.0.7 - cycler: 0.11.0 - debugpy: 1.6.7 - decorator: 5.1.1 - defusedxml: 0.7.1 - docopt: 0.6.2 - docstring-parser: 0.16 - exceptiongroup: 1.2.0 - executing: 2.0.1 - fastapi: 0.104.1 - fastjsonschema: 2.19.1 - filelock: 3.13.1 - fonttools: 4.25.0 - fqdn: 1.5.1 - frozenlist: 1.4.1 - fsspec: 2023.10.0 - gmpy2: 2.1.2 - greenlet: 3.0.2 - grpcio: 1.48.2 - h11: 0.14.0 - httpcore: 1.0.4 - httpx: 0.27.0 - hydra-core: 1.3.2 - idna: 2.10 - importlib-metadata: 7.0.1 - importlib-resources: 6.1.0 - ipyflow: 0.0.198 - ipyflow-core: 0.0.198 - ipykernel: 6.29.0 - ipython: 8.18.1 - ipython-genutils: 0.2.0 - ipywidgets: 8.1.2 - isoduration: 20.11.0 - jedi: 0.19.1 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.3.2 - json5: 0.9.17 - jsonargparse: 4.29.0 - jsonnet: 0.17.0 - jsonpointer: 2.4 - jsonschema: 4.19.2 - jsonschema-specifications: 2023.12.1 - jupyter: 1.0.0 - jupyter-client: 8.6.0 - jupyter-console: 6.6.3 - jupyter-core: 5.7.1 - jupyter-events: 0.9.0 - jupyter-lsp: 2.2.3 - jupyter-server: 2.12.5 - jupyter-server-terminals: 0.5.2 - jupyterlab: 4.1.2 - jupyterlab-pygments: 0.3.0 - jupyterlab-server: 2.25.3 - jupyterlab-widgets: 3.0.10 - kiwisolver: 1.4.4 - lightning: 2.2.5 - lightning-utilities: 0.11.2 - mako: 1.3.0 - markdown: 3.6 - markdown-it-py: 3.0.0 - markupsafe: 2.1.1 - matplotlib: 3.8.0 - matplotlib-inline: 0.1.6 - mdurl: 0.1.2 - mistune: 3.0.2 - mkl-fft: 1.3.8 - mkl-random: 1.2.4 - mkl-service: 2.4.0 - monotonic: 1.6 - mpmath: 1.3.0 - multidict: 6.0.5 - munkres: 1.1.4 - mypy-extensions: 1.0.0 - nbclassic: 1.0.0 - nbclient: 0.9.0 - nbconvert: 7.16.1 - nbformat: 5.9.2 - nest-asyncio: 1.6.0 - networkx: 3.1 - notebook: 7.1.1 - notebook-shim: 0.2.4 - numexpr: 2.8.7 - numpy: 1.26.2 - omegaconf: 2.3.0 - overrides: 7.7.0 - packaging: 23.1 - pandas: 2.1.1 - pandocfilters: 1.5.1 - parso: 0.8.3 - pathspec: 0.12.1 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 10.0.1 - pip: 23.3.1 - pipreqs: 0.4.13 - platformdirs: 4.1.0 - ply: 3.11 - pretty-errors: 1.2.25 - prometheus-client: 0.20.0 - prompt-toolkit: 3.0.42 - protobuf: 3.20.3 - psutil: 5.9.1 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py3nvml: 0.2.7 - pyccolo: 0.0.52 - pycparser: 2.21 - pydantic: 2.5.2 - pydantic-core: 2.14.5 - pygments: 2.17.2 - pyopenssl: 23.2.0 - pyparsing: 3.0.9 - pyqt5: 5.15.10 - pyqt5-sip: 12.13.0 - pysocks: 1.7.1 - python-dateutil: 2.8.2 - python-json-logger: 2.0.7 - pytorch-lightning: 2.2.2 - pytz: 2023.3.post1 - pyyaml: 6.0.1 - pyzmq: 25.1.2 - qtconsole: 5.5.1 - qtpy: 2.4.1 - referencing: 0.30.2 - requests: 2.25.1 - restrictedpython: 7.0 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.7.1 - rpds-py: 0.10.6 - s3transfer: 0.10.0 - scikit-learn: 1.4.1.post1 - scipy: 1.12.0 - seaborn: 0.12.2 - segment-analytics-python: 2.2.3 - send2trash: 1.8.2 - setuptools: 68.0.0 - sip: 6.7.12 - six: 1.16.0 - sniffio: 1.3.0 - soupsieve: 2.5 - sqlalchemy: 1.4.50 - stack-data: 0.6.2 - starlette: 0.27.0 - sympy: 1.12 - tensorboard: 2.17.0 - tensorboard-data-server: 0.7.0 - tensorboardx: - terminado: 0.18.0 - threadpoolctl: 3.4.0 - tinycss2: 1.2.1 - tomli: 2.0.1 - torch: 2.1.1 - torchattacks: 3.5.1 - torchaudio: 2.1.1 - torchmetrics: 1.4.0 - torchvision: 0.16.1 - tornado: 6.3.3 - tqdm: 4.65.0 - traitlets: 5.14.1 - triton: 2.1.0 - types-python-dateutil: - typeshed-client: 2.5.1 - typing-extensions: 4.9.0 - tzdata: 2023.3 - uri-template: 1.3.0 - urllib3: 1.26.18 - uvicorn: 0.24.0.post1 - validators: 0.18.2 - wcwidth: 0.2.13 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.7.0 - websockets: 12.0 - werkzeug: 3.0.3 - wheel: 0.41.2 - widgetsnbextension: 4.0.10 - xmltodict: 0.13.0 - yarg: 0.1.9 - yarl: 1.9.4 - zipp: 3.11.0
* System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.9.18 - release: 6.2.0-39-generic - version: #40-Ubuntu SMP PREEMPT_DYNAMIC Tue Nov 14 14:18:00 UTC 2023

matrix72c commented 2 months ago

I found that if I update acc only in the validation step and test step, the results are the same. However, when a single acc metric is used in both the training and validation steps, the validation result during fit differs from the result of calling validate() on its own.
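A possible explanation, not confirmed in this thread, is that a stateful metric shared between stages still holds training updates when validation starts, so the validation epoch value is computed over both sets of batches. The sketch below uses a toy accumulator (`RunningAccuracy`, a hypothetical stand-in, not the torchmetrics class) and made-up batches to show how a shared state inflates the validation number compared with a dedicated one.

```python
class RunningAccuracy:
    """Toy stand-in for a stateful metric: accumulates until reset() is called."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        # Count element-wise matches and add them to the running state.
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total

    def reset(self):
        self.correct = 0
        self.total = 0


# Hypothetical batches as (predictions, labels); values are made up for illustration.
train_batches = [([0, 1, 2, 3], [0, 1, 2, 3]), ([1, 1, 2, 0], [1, 1, 2, 0])]  # 8 of 8 correct
val_batches = [([0, 1, 0, 0], [0, 1, 2, 3])]  # 2 of 4 correct

# Shared accumulator: training updates are still in the state when validation runs.
shared = RunningAccuracy()
for preds, targets in train_batches:
    shared.update(preds, targets)
for preds, targets in val_batches:
    shared.update(preds, targets)
shared_val_acc = shared.compute()  # (8 + 2) / (8 + 4)

# Separate accumulators: the validation metric only ever sees validation batches.
train_acc, val_acc = RunningAccuracy(), RunningAccuracy()
for preds, targets in train_batches:
    train_acc.update(preds, targets)
for preds, targets in val_batches:
    val_acc.update(preds, targets)
separate_val_acc = val_acc.compute()  # 2 / 4

print(f"shared: {shared_val_acc:.3f}, separate: {separate_val_acc:.3f}")
```

In the module above, the equivalent change would be to create one `Accuracy` instance per stage (for example `self.train_acc`, `self.val_acc`, and `self.test_acc`, names chosen here for illustration) and update only the matching one in each step.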