Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

LearningRateFinder not working with CLI optimizers #16787

Open rusmux opened 1 year ago

rusmux commented 1 year ago

Bug description

LearningRateFinder does not update the optimizer if it is defined from the CLI or yaml config file.

For example, I define in train.yaml:

...
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 1.5e-3
...

And I set the callback:

LearningRateFinder(update_attr=True)
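
For reference, the callback can also be declared in the same yaml config instead of in code. This is a sketch, assuming the PL 1.9 class path for the callback:

```yaml
trainer:
  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateFinder
      init_args:
        update_attr: true
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 1.5e-3
```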

At the start, it finds the best learning rate:

Screenshot 82

But after that, it still uses the learning rate I provided:

Screenshot 83

I also tried to do it manually, like this:

Screenshot 84

But I had the same result.

How to reproduce the bug

Define an optimizer in a yaml config file. Add the `LearningRateFinder` callback.

Error messages and logs


Environment

Current environment ``` * CUDA: - GPU: - NVIDIA RTX A4000 - available: True - version: 11.7 * Lightning: - lightning-utilities: 0.6.0.post0 - pytorch-lightning: 1.9.1 - torch: 1.13.1 - torchmetrics: 0.11.1 - torchvision: 0.14.1 * Packages: - aiobotocore: 2.4.2 - aiofiles: 22.1.0 - aiohttp: 3.8.4 - aiohttp-retry: 2.8.3 - aioitertools: 0.11.0 - aiosignal: 1.3.1 - aiosqlite: 0.18.0 - albumentations: 1.3.0 - amqp: 5.1.1 - antlr4-python3-runtime: 4.9.3 - anyio: 3.6.2 - appdirs: 1.4.4 - argcomplete: 2.0.0 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - astor: 0.8.1 - asttokens: 2.2.1 - async-timeout: 4.0.2 - asyncssh: 2.13.0 - atpublic: 3.1.1 - attrs: 22.2.0 - babel: 2.11.0 - backcall: 0.2.0 - bandit: 1.7.4 - beautifulsoup4: 4.11.2 - billiard: 3.6.4.0 - bleach: 6.0.0 - boto3: 1.24.59 - botocore: 1.27.59 - celery: 5.2.7 - certifi: 2022.12.7 - cffi: 1.15.1 - cfgv: 3.3.1 - charset-normalizer: 3.0.1 - clearml: 1.9.1 - click: 8.1.3 - click-didyoumean: 0.3.0 - click-plugins: 1.1.1 - click-repl: 0.2.0 - colorama: 0.4.6 - comm: 0.1.2 - configobj: 5.0.8 - contourpy: 1.0.7 - cryptography: 39.0.1 - cycler: 0.11.0 - dacite: 1.8.0 - darglint: 1.8.1 - debugpy: 1.6.6 - decorator: 5.1.1 - defusedxml: 0.7.1 - deprecated: 1.2.13 - dictdiffer: 0.9.0 - dill: 0.3.6 - diskcache: 5.4.0 - distlib: 0.3.6 - distro: 1.8.0 - dnspython: 2.3.0 - docstring-parser: 0.15 - docutils: 0.19 - dpath: 2.1.4 - dulwich: 0.21.2 - dvc: 2.45.0 - dvc-data: 0.40.1 - dvc-http: 2.30.2 - dvc-objects: 0.19.3 - dvc-render: 0.1.2 - dvc-s3: 2.21.0 - dvc-studio-client: 0.4.0 - dvc-task: 0.1.11 - dvclive: 2.0.2 - eradicate: 2.1.0 - eventlet: 0.33.3 - exceptiongroup: 1.1.0 - executing: 1.2.0 - fastjsonschema: 2.16.2 - fiftyone: 0.18.0 - fiftyone-brain: 0.9.2 - fiftyone-db: 0.4.0 - filelock: 3.9.0 - flake8: 4.0.1 - flake8-bandit: 3.0.0 - flake8-broken-line: 0.5.0 - flake8-bugbear: 22.12.6 - flake8-commas: 2.1.0 - flake8-comprehensions: 3.10.1 - flake8-debugger: 4.1.2 - flake8-docstrings: 1.7.0 - 
flake8-eradicate: 1.4.0 - flake8-isort: 4.2.0 - flake8-polyfill: 1.0.2 - flake8-quotes: 3.3.2 - flake8-rst-docstrings: 0.2.7 - flake8-string-format: 0.3.0 - flatten-dict: 0.4.2 - flufl.lock: 7.1.1 - fonttools: 4.38.0 - fqdn: 1.5.1 - frozenlist: 1.3.3 - fsspec: 2023.1.0 - funcy: 1.18 - furl: 2.1.3 - future: 0.18.3 - gitdb: 4.0.10 - gitpython: 3.1.30 - glob2: 0.7 - grandalf: 0.8 - graphql-core: 3.2.3 - greenlet: 2.0.2 - h11: 0.14.0 - h2: 4.1.0 - hpack: 4.0.0 - httpcore: 0.16.3 - httpx: 0.23.3 - huggingface-hub: 0.12.0 - hydra-core: 1.3.1 - hypercorn: 0.14.3 - hyperframe: 6.0.1 - identify: 2.5.18 - idna: 3.4 - imageio: 2.25.1 - importlib-resources: 5.10.2 - iniconfig: 2.0.0 - ipykernel: 6.21.2 - ipython: 8.10.0 - ipython-genutils: 0.2.0 - ipywidgets: 8.0.4 - isoduration: 20.11.0 - isort: 5.12.0 - iterative-telemetry: 0.0.7 - jedi: 0.18.2 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.2.0 - json5: 0.9.11 - jsonargparse: 4.19.0 - jsonpointer: 2.3 - jsonschema: 4.17.3 - jupyter-client: 8.0.2 - jupyter-contrib-core: 0.4.2 - jupyter-contrib-nbextensions: 0.7.0 - jupyter-core: 5.2.0 - jupyter-events: 0.5.0 - jupyter-highlight-selected-word: 0.2.0 - jupyter-nbextensions-configurator: 0.6.1 - jupyter-server: 2.2.1 - jupyter-server-fileid: 0.6.0 - jupyter-server-terminals: 0.4.4 - jupyter-server-ydoc: 0.6.1 - jupyter-ydoc: 0.2.2 - jupyterlab: 3.6.1 - jupyterlab-pygments: 0.2.2 - jupyterlab-server: 2.19.0 - jupyterlab-widgets: 3.0.5 - kaleido: 0.2.1 - kiwisolver: 1.4.4 - kombu: 5.2.4 - lightning-utilities: 0.6.0.post0 - lxml: 4.9.2 - markdown-it-py: 2.1.0 - markupsafe: 2.1.2 - matplotlib: 3.7.0 - matplotlib-inline: 0.1.6 - mccabe: 0.6.1 - mdurl: 0.1.2 - mistune: 2.0.5 - mongoengine: 0.24.2 - motor: 3.1.1 - multidict: 6.0.4 - nanotime: 0.5.2 - nbclassic: 0.5.1 - nbclient: 0.7.2 - nbconvert: 7.2.9 - nbformat: 5.7.3 - ndjson: 0.3.1 - nest-asyncio: 1.5.6 - networkx: 3.0 - nodeenv: 1.7.0 - notebook: 6.5.2 - notebook-shim: 0.2.2 - numpy: 1.24.2 - nvidia-cublas-cu11: 11.10.3.66 - 
nvidia-cuda-nvrtc-cu11: 11.7.99 - nvidia-cuda-runtime-cu11: 11.7.99 - nvidia-cudnn-cu11: 8.5.0.96 - omegaconf: 2.3.0 - onnx: 1.13.0 - opencv-python-headless: 4.7.0.68 - orderedmultidict: 1.0.1 - orjson: 3.8.6 - packaging: 23.0 - pandas: 1.5.3 - pandocfilters: 1.5.0 - parso: 0.8.3 - pathlib2: 2.3.7.post1 - pathspec: 0.11.0 - patool: 1.12 - pbr: 5.11.1 - pep8-naming: 0.13.2 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.4.0 - pip: 23.0 - platformdirs: 3.0.0 - plotly: 5.13.0 - pluggy: 1.0.0 - pprintpp: 0.4.0 - pre-commit: 2.21.0 - priority: 2.0.0 - prometheus-client: 0.16.0 - prompt-toolkit: 3.0.36 - protobuf: 3.20.3 - psutil: 5.9.4 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pycodestyle: 2.8.0 - pycparser: 2.21 - pydocstyle: 6.3.0 - pydot: 1.4.2 - pyflakes: 2.4.0 - pygit2: 1.11.1 - pygments: 2.14.0 - pygtrie: 2.5.0 - pyjwt: 2.4.0 - pymongo: 4.3.3 - pyparsing: 3.0.9 - pyrsistent: 0.19.3 - pytest: 7.2.1 - python-dateutil: 2.8.2 - python-json-logger: 2.0.6 - pytorch-lightning: 1.9.1 - pytz: 2022.7.1 - pytz-deprecation-shim: 0.1.0.post0 - pywavelets: 1.4.1 - pyyaml: 6.0 - pyzmq: 25.0.0 - qudida: 0.0.4 - requests: 2.28.2 - restructuredtext-lint: 1.4.0 - retrying: 1.3.4 - rfc3339-validator: 0.1.4 - rfc3986: 1.5.0 - rfc3986-validator: 0.1.1 - rich: 13.3.1 - ruamel.yaml: 0.17.21 - ruamel.yaml.clib: 0.2.7 - s3fs: 2023.1.0 - s3transfer: 0.6.0 - scikit-image: 0.19.3 - scikit-learn: 1.2.1 - scipy: 1.10.0 - scmrepo: 0.1.9 - send2trash: 1.8.0 - setuptools: 67.3.1 - shortuuid: 1.0.11 - shtab: 1.5.8 - six: 1.16.0 - smmap: 5.0.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - sortedcontainers: 2.4.0 - soupsieve: 2.4 - sqltrie: 0.0.28 - sse-starlette: 0.10.3 - sseclient-py: 1.7.2 - stack-data: 0.6.2 - starlette: 0.20.4 - stevedore: 5.0.0 - strawberry-graphql: 0.138.1 - tabulate: 0.9.0 - tenacity: 8.2.1 - tensorboardx: 2.6 - terminado: 0.17.1 - threadpoolctl: 3.1.0 - tifffile: 2023.2.3 - timm: 0.6.12 - tinycss2: 1.2.1 - toml: 0.10.2 - tomli: 2.0.1 - tomlkit: 0.11.6 - torch: 1.13.1 - 
torchmetrics: 0.11.1 - torchvision: 0.14.1 - tornado: 6.2 - tqdm: 4.64.1 - traitlets: 5.9.0 - typeshed-client: 2.2.0 - typing-extensions: 4.5.0 - tzdata: 2022.7 - tzlocal: 4.2 - universal-analytics-python3: 1.1.1 - uri-template: 1.2.0 - urllib3: 1.26.14 - vine: 5.0.0 - virtualenv: 20.19.0 - voluptuous: 0.13.1 - voxel51-eta: 0.8.3 - wcwidth: 0.2.6 - webcolors: 1.12 - webencodings: 0.5.1 - websocket-client: 1.5.1 - wemake-python-styleguide: 0.17.0 - wheel: 0.38.4 - widgetsnbextension: 4.0.5 - wrapt: 1.14.1 - wsproto: 1.2.0 - xmltodict: 0.13.0 - y-py: 0.5.5 - yarl: 1.8.2 - ypy-websocket: 0.8.2 - zc.lockfile: 2.0 * System: - OS: Linux - architecture: - 64bit - - processor: - python: 3.10.10 - version: #152-Ubuntu SMP Wed Nov 23 20:19:22 UTC 2022 ```

More info

I think the problem is in how and when the optimizers and schedulers are instantiated, because when I ran the same code with the batch size finder instead, it worked as expected:

Screenshot 85

It used the found batch size in training.

For now, as I understand it, the way to use `LearningRateFinder` is to manually define `configure_optimizers()` in the `LightningModule`. But that way I can't change the optimizer from the yaml config file.

weicao1990 commented 1 year ago

Hi, I also faced this issue. My solution is to add a `before_fit` method to your customized CLI class:

def before_fit(self):
    tuner = Tuner(self.trainer)
    tuner.lr_find(self.model, datamodule=self.datamodule)

This way, PL will execute `configure_optimizers` after obtaining the optimal LR. Otherwise, if we use the `LearningRateFinder` callback, `configure_optimizers` will not be executed again after the optimal LR is found.

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!