Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

trainer.validate() get different result from trainer.fit #20179

Open matrix72c opened 3 months ago

matrix72c commented 3 months ago

Bug description

I'm training a ResNet50 model and using a ModelCheckpoint callback (weights only) to save the best model. However, I found that the results differ between fit and validate. I start training with the Lightning CLI with run=False and then manually call fit and validate.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

Main code


import lightning as L
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchmetrics import Accuracy
from lightning.pytorch.cli import LightningCLI

class ResNet(L.LightningModule):
    def __init__(
        self,
        num_classes: int,
        use_pretrained: bool = True,
        lr: float = 1e-3,
        step_size: int = 10,
        gamma: float = 0.1,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.model = torchvision.models.resnet50(
            weights=(
                torchvision.models.ResNet50_Weights.DEFAULT if use_pretrained else None
            ),
        )
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.acc = Accuracy(task="multiclass", num_classes=num_classes)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma
        )
        return [optimizer], [scheduler]

    def forward(self, x):
        x = self.model(x)
        return x

    def shared_step(self, batch):
        img, label, _ = batch
        logits = self(img)
        loss = F.cross_entropy(logits, label)
        self.acc(logits, label)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.acc, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        _ = self.shared_step(batch)
        self.log("val_acc", self.acc, prog_bar=True, on_epoch=True, on_step=True)

    def test_step(self, batch, batch_idx):
        _ = self.shared_step(batch)
        self.log("test_acc", self.acc, prog_bar=True, on_epoch=True, on_step=True)

if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    cli = LightningCLI(save_config_callback=None, run=False)

    cli.trainer.fit(cli.model, cli.datamodule)
    cli.trainer.validate(cli.model, cli.datamodule, ckpt_path="checkpoint/resnet_CUB.ckpt")

config file

seed_everything: 42

trainer:
  max_epochs: 1000
  log_every_n_steps: 1
  logger:
    class_path: aim.pytorch_lightning.AimLogger
    init_args:
      run_name: resnet_CUB
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_acc
        dirpath: checkpoints/
        filename: resnet_CUB
        save_top_k: 1
        mode: max
        save_weights_only: True
        enable_version_counter: False
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        monitor: val_acc
        mode: max
        min_delta: 0.001
        patience: 0

model:
  class_path: model.ResNet
  init_args:
    num_classes: 200

data:
  class_path: dataset.CUB
  init_args:
    data_path: ./data
    batch_size: 128

Error messages and logs

During fit, the final val_acc_epoch is 0.86, whereas trainer.validate reports 0.74.

Environment

Current environment

* CUDA:
    - GPU:
        - NVIDIA GeForce RTX 4090
    - available: True
    - version: 11.8
* Lightning:
    - lightning: 2.2.5
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.2.2
    - torch: 2.1.1
    - torchattacks: 3.5.1
    - torchaudio: 2.1.1
    - torchmetrics: 1.4.0
    - torchvision: 0.16.1
* Packages:
    - absl-py: 2.1.0
    - aim: 3.22.0
    - aim-ui: 3.22.0
    - aimrecords: 0.0.7
    - aimrocks: 0.5.2
    - aiofiles: 23.2.1
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - alembic: 1.13.0
    - annotated-types: 0.6.0
    - antlr4-python3-runtime: 4.9.3
    - anyio: 3.7.1
    - argcomplete: 3.4.0
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - asttokens: 2.4.1
    - astunparse: 1.6.3
    - async-lru: 2.0.4
    - async-timeout: 4.0.3
    - attrs: 23.1.0
    - babel: 2.14.0
    - backoff: 2.2.1
    - base58: 2.0.1
    - beautifulsoup4: 4.12.3
    - bitsandbytes: 0.41.0
    - black: 24.2.0
    - bleach: 6.1.0
    - boto3: 1.34.62
    - botocore: 1.34.62
    - bottleneck: 1.3.5
    - brotli: 1.0.9
    - cachetools: 5.3.2
    - certifi: 2024.7.4
    - cffi: 1.16.0
    - chardet: 4.0.0
    - charset-normalizer: 2.0.4
    - click: 8.1.7
    - colorama: 0.4.6
    - comm: 0.2.1
    - contourpy: 1.2.0
    - cryptography: 41.0.7
    - cycler: 0.11.0
    - debugpy: 1.6.7
    - decorator: 5.1.1
    - defusedxml: 0.7.1
    - docopt: 0.6.2
    - docstring-parser: 0.16
    - exceptiongroup: 1.2.0
    - executing: 2.0.1
    - fastapi: 0.104.1
    - fastjsonschema: 2.19.1
    - filelock: 3.13.1
    - fonttools: 4.25.0
    - fqdn: 1.5.1
    - frozenlist: 1.4.1
    - fsspec: 2023.10.0
    - gmpy2: 2.1.2
    - greenlet: 3.0.2
    - grpcio: 1.48.2
    - h11: 0.14.0
    - httpcore: 1.0.4
    - httpx: 0.27.0
    - hydra-core: 1.3.2
    - idna: 2.10
    - importlib-metadata: 7.0.1
    - importlib-resources: 6.1.0
    - ipyflow: 0.0.198
    - ipyflow-core: 0.0.198
    - ipykernel: 6.29.0
    - ipython: 8.18.1
    - ipython-genutils: 0.2.0
    - ipywidgets: 8.1.2
    - isoduration: 20.11.0
    - jedi: 0.19.1
    - jinja2: 3.1.2
    - jmespath: 1.0.1
    - joblib: 1.3.2
    - json5: 0.9.17
    - jsonargparse: 4.29.0
    - jsonnet: 0.17.0
    - jsonpointer: 2.4
    - jsonschema: 4.19.2
    - jsonschema-specifications: 2023.12.1
    - jupyter: 1.0.0
    - jupyter-client: 8.6.0
    - jupyter-console: 6.6.3
    - jupyter-core: 5.7.1
    - jupyter-events: 0.9.0
    - jupyter-lsp: 2.2.3
    - jupyter-server: 2.12.5
    - jupyter-server-terminals: 0.5.2
    - jupyterlab: 4.1.2
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.25.3
    - jupyterlab-widgets: 3.0.10
    - kiwisolver: 1.4.4
    - lightning: 2.2.5
    - lightning-utilities: 0.11.2
    - mako: 1.3.0
    - markdown: 3.6
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.1
    - matplotlib: 3.8.0
    - matplotlib-inline: 0.1.6
    - mdurl: 0.1.2
    - mistune: 3.0.2
    - mkl-fft: 1.3.8
    - mkl-random: 1.2.4
    - mkl-service: 2.4.0
    - monotonic: 1.6
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - munkres: 1.1.4
    - mypy-extensions: 1.0.0
    - nbclassic: 1.0.0
    - nbclient: 0.9.0
    - nbconvert: 7.16.1
    - nbformat: 5.9.2
    - nest-asyncio: 1.6.0
    - networkx: 3.1
    - notebook: 7.1.1
    - notebook-shim: 0.2.4
    - numexpr: 2.8.7
    - numpy: 1.26.2
    - omegaconf: 2.3.0
    - overrides: 7.7.0
    - packaging: 23.1
    - pandas: 2.1.1
    - pandocfilters: 1.5.1
    - parso: 0.8.3
    - pathspec: 0.12.1
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 10.0.1
    - pip: 23.3.1
    - pipreqs: 0.4.13
    - platformdirs: 4.1.0
    - ply: 3.11
    - pretty-errors: 1.2.25
    - prometheus-client: 0.20.0
    - prompt-toolkit: 3.0.42
    - protobuf: 3.20.3
    - psutil: 5.9.1
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py3nvml: 0.2.7
    - pyccolo: 0.0.52
    - pycparser: 2.21
    - pydantic: 2.5.2
    - pydantic-core: 2.14.5
    - pygments: 2.17.2
    - pyopenssl: 23.2.0
    - pyparsing: 3.0.9
    - pyqt5: 5.15.10
    - pyqt5-sip: 12.13.0
    - pysocks: 1.7.1
    - python-dateutil: 2.8.2
    - python-json-logger: 2.0.7
    - pytorch-lightning: 2.2.2
    - pytz: 2023.3.post1
    - pyyaml: 6.0.1
    - pyzmq: 25.1.2
    - qtconsole: 5.5.1
    - qtpy: 2.4.1
    - referencing: 0.30.2
    - requests: 2.25.1
    - restrictedpython: 7.0
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.7.1
    - rpds-py: 0.10.6
    - s3transfer: 0.10.0
    - scikit-learn: 1.4.1.post1
    - scipy: 1.12.0
    - seaborn: 0.12.2
    - segment-analytics-python: 2.2.3
    - send2trash: 1.8.2
    - setuptools: 68.0.0
    - sip: 6.7.12
    - six: 1.16.0
    - sniffio: 1.3.0
    - soupsieve: 2.5
    - sqlalchemy: 1.4.50
    - stack-data: 0.6.2
    - starlette: 0.27.0
    - sympy: 1.12
    - tensorboard: 2.17.0
    - tensorboard-data-server: 0.7.0
    - tensorboardx: 2.6.2.2
    - terminado: 0.18.0
    - threadpoolctl: 3.4.0
    - tinycss2: 1.2.1
    - tomli: 2.0.1
    - torch: 2.1.1
    - torchattacks: 3.5.1
    - torchaudio: 2.1.1
    - torchmetrics: 1.4.0
    - torchvision: 0.16.1
    - tornado: 6.3.3
    - tqdm: 4.65.0
    - traitlets: 5.14.1
    - triton: 2.1.0
    - types-python-dateutil: 2.8.19.20240106
    - typeshed-client: 2.5.1
    - typing-extensions: 4.9.0
    - tzdata: 2023.3
    - uri-template: 1.3.0
    - urllib3: 1.26.18
    - uvicorn: 0.24.0.post1
    - validators: 0.18.2
    - wcwidth: 0.2.13
    - webcolors: 1.13
    - webencodings: 0.5.1
    - websocket-client: 1.7.0
    - websockets: 12.0
    - werkzeug: 3.0.3
    - wheel: 0.41.2
    - widgetsnbextension: 4.0.10
    - xmltodict: 0.13.0
    - yarg: 0.1.9
    - yarl: 1.9.4
    - zipp: 3.11.0
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.9.18
    - release: 6.2.0-39-generic
    - version: #40-Ubuntu SMP PREEMPT_DYNAMIC Tue Nov 14 14:18:00 UTC 2023

More info

No response

matrix72c commented 2 months ago

I found that if I update the acc metric only in the validation and test steps, the results are the same. However, when a single acc metric is shared between the training and validation steps, the validation result during fit differs from the result of a standalone validate() call.