Lightning-AI / pytorch-lightning


EarlyStopping interfered by LearningRateFinder #19575


zhf231298 commented 6 months ago

Bug description

When EarlyStopping is used together with LearningRateFinder, the early stopping check is triggered $n$ steps before validation, where $n$ is the number of steps executed by the learning rate finder. This becomes a problem when the early stopping check is conditioned on a validation metric, because at the time of the check the validation metric has not been computed yet.

What version are you seeing the problem on?

v2.2, master

How to reproduce the bug

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateFinder

callbacks = [
    EarlyStopping(monitor="val/loss", patience=10),
    LearningRateFinder(),
]
trainer = Trainer(
    log_every_n_steps=1,
    gradient_clip_val=1,
    callbacks=callbacks,
)
```

Error messages and logs

RuntimeError: Early stopping conditioned on metric `val/loss` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: `train/loss`, `train/raw_loss`, `train/mae`, `train/raw_mae`, `train/mae_improvement`

Environment

Current environment ``` * CUDA: - GPU: - NVIDIA GeForce RTX 4080 - available: True - version: 12.1 * Lightning: - lightning: 2.2.0.post0 - lightning-utilities: 0.10.1 - pytorch-lightning: 2.2.0.post0 - torch: 2.2.0 - torchmetrics: 1.3.1 * Packages: - absl-py: 2.1.0 - aiohttp: 3.9.3 - aiosignal: 1.3.1 - arrow: 1.3.0 - attrs: 23.2.0 - boto3: 1.34.25 - botocore: 1.34.25 - bravado: 11.0.3 - bravado-core: 6.1.1 - certifi: 2024.2.2 - cffi: 1.16.0 - charset-normalizer: 3.3.2 - click: 8.1.7 - contourpy: 1.2.0 - cycler: 0.12.1 - filelock: 3.13.1 - flatbuffers: 23.5.26 - fonttools: 4.47.2 - fqdn: 1.5.1 - frozenlist: 1.4.1 - fsspec: 2024.2.0 - future: 1.0.0 - gitdb: 4.0.11 - gitpython: 3.1.42 - idna: 3.6 - iniconfig: 2.0.0 - isoduration: 20.11.0 - jinja2: 3.1.3 - jmespath: 1.0.1 - jsonpointer: 2.4 - jsonref: 1.1.0 - jsonschema: 4.21.1 - jsonschema-specifications: 2023.12.1 - kiwisolver: 1.4.5 - lightning: 2.2.0.post0 - lightning-utilities: 0.10.1 - markupsafe: 2.1.5 - matplotlib: 3.8.2 - mediapipe: 0.10.9 - monotonic: 1.6 - mpmath: 1.3.0 - msgpack: 1.0.7 - multidict: 6.0.5 - mypy: 1.8.0 - mypy-extensions: 1.0.0 - neptune: 1.9.1 - networkx: 3.2.1 - numpy: 1.26.3 - nvidia-cublas-cu12: 12.1.3.1 - nvidia-cuda-cupti-cu12: 12.1.105 - nvidia-cuda-nvrtc-cu12: 12.1.105 - nvidia-cuda-runtime-cu12: 12.1.105 - nvidia-cudnn-cu12: 8.9.2.26 - nvidia-cufft-cu12: 11.0.2.54 - nvidia-curand-cu12: 10.3.2.106 - nvidia-cusolver-cu12: 11.4.5.107 - nvidia-cusparse-cu12: 12.1.0.106 - nvidia-nccl-cu12: 2.19.3 - nvidia-nvjitlink-cu12: 12.3.101 - nvidia-nvtx-cu12: 12.1.105 - oauthlib: 3.2.2 - opencv-contrib-python: 4.9.0.80 - opencv-python: 4.9.0.80 - packaging: 23.2 - pandas: 2.2.0 - pandas-stubs: 2.1.4.231227 - pillow: 10.2.0 - pip: 23.2.1 - pluggy: 1.4.0 - protobuf: 3.20.3 - psutil: 5.9.8 - pyarrow: 15.0.0 - pycparser: 2.21 - pyjwt: 2.8.0 - pyparsing: 3.1.1 - pytest: 7.4.4 - python-dateutil: 2.8.2 - pytorch-lightning: 2.2.0.post0 - pytz: 2023.3.post1 - pyyaml: 6.0.1 - referencing: 0.33.0 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rpds-py: 0.18.0 - ruff: 0.1.14 - s3transfer: 0.10.0 - scipy: 1.12.0 - setuptools: 69.1.0 - simplejson: 3.19.2 - six: 1.16.0 - smmap: 5.0.1 - sounddevice: 0.4.6 - swagger-spec-validator: 3.0.3 - sympy: 1.12 - torch: 2.2.0 - torchmetrics: 1.3.1 - tqdm: 4.66.1 - triton: 2.2.0 - types-python-dateutil: 2.8.19.20240106 - types-pytz: 2023.3.1.1 - types-pyyaml: 6.0.12.12 - types-tqdm: 4.66.0.20240106 - typing-extensions: 4.9.0 - tzdata: 2023.4 - uri-template: 1.3.0 - urllib3: 2.0.7 - webcolors: 1.13 - websocket-client: 1.7.0 - yarl: 1.9.4 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.11.7 - release: 5.15.0-97-generic - version: #107-Ubuntu SMP Wed Feb 7 13:26:48 UTC 2024```

More info

I have not tried to reproduce this with other callbacks, but I suspect the same issue could occur with any callback that runs the model before the fit loop actually starts.

famura commented 2 months ago

I just encountered the same issue using pytorch-lightning version 2.2.4.

It seems like the learning rate finder iterations count towards the counter that triggers `on_advance_end`, which then runs into a problem: the early stopping callback can't find its metric, because we only log it during the validation step and not the training step.
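
A possible workaround, offered as a sketch rather than a verified fix: drop `LearningRateFinder` from the callbacks and run the search once through the `Tuner` API before calling `fit`, so the finder's extra iterations never happen inside the fit loop that EarlyStopping observes.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.tuner import Tuner

trainer = Trainer(
    log_every_n_steps=1,
    gradient_clip_val=1,
    callbacks=[EarlyStopping(monitor="val/loss", patience=10)],
)

# Run the LR search outside of fit(); `model` is your LightningModule instance.
tuner = Tuner(trainer)
tuner.lr_find(model)  # by default writes the suggested value back to model.lr / model.learning_rate

trainer.fit(model)
```

I have not checked whether this also sidesteps the miscounting for other callbacks that run the model before fitting starts.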