Open: rusmux opened this issue 1 year ago
Hi, I also ran into this issue. My solution is to add a `before_fit` method to your custom CLI class:
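```python
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.tuner.tuning import Tuner

class MyLightningCLI(LightningCLI):  # your custom CLI subclass
    def before_fit(self):
        # Run the LR finder before fitting, so configure_optimizers is called after it
        tuner = Tuner(self.trainer)
        tuner.lr_find(self.model, datamodule=self.datamodule)
```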
This way, Lightning runs `configure_optimizers` after the optimal LR has been found. If we use the `LearningRateFinder` callback instead, `configure_optimizers` is not run again after the optimal LR is found.
Bug description
`LearningRateFinder` does not update the optimizer if the optimizer is defined from the CLI or a YAML config file. For example, I define it in `train.yaml`:
And I set the callback:
At the start, it finds the best learning rate:
But after that, it still uses the learning rate I provided:
I also tried to do it manually, like this:
But I got the same result.
How to reproduce the bug
Error messages and logs
Environment
Current environment
```
* CUDA:
    - GPU:
        - NVIDIA RTX A4000
    - available: True
    - version: 11.7
* Lightning:
    - lightning-utilities: 0.6.0.post0
    - pytorch-lightning: 1.9.1
    - torch: 1.13.1
    - torchmetrics: 0.11.1
    - torchvision: 0.14.1
* Packages:
    - aiobotocore: 2.4.2
    - aiofiles: 22.1.0
    - aiohttp: 3.8.4
    - aiohttp-retry: 2.8.3
    - aioitertools: 0.11.0
    - aiosignal: 1.3.1
    - aiosqlite: 0.18.0
    - albumentations: 1.3.0
    - amqp: 5.1.1
    - antlr4-python3-runtime: 4.9.3
    - anyio: 3.6.2
    - appdirs: 1.4.4
    - argcomplete: 2.0.0
    - argon2-cffi: 21.3.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.2.3
    - astor: 0.8.1
    - asttokens: 2.2.1
    - async-timeout: 4.0.2
    - asyncssh: 2.13.0
    - atpublic: 3.1.1
    - attrs: 22.2.0
    - babel: 2.11.0
    - backcall: 0.2.0
    - bandit: 1.7.4
    - beautifulsoup4: 4.11.2
    - billiard: 3.6.4.0
    - bleach: 6.0.0
    - boto3: 1.24.59
    - botocore: 1.27.59
    - celery: 5.2.7
    - certifi: 2022.12.7
    - cffi: 1.15.1
    - cfgv: 3.3.1
    - charset-normalizer: 3.0.1
    - clearml: 1.9.1
    - click: 8.1.3
    - click-didyoumean: 0.3.0
    - click-plugins: 1.1.1
    - click-repl: 0.2.0
    - colorama: 0.4.6
    - comm: 0.1.2
    - configobj: 5.0.8
    - contourpy: 1.0.7
    - cryptography: 39.0.1
    - cycler: 0.11.0
    - dacite: 1.8.0
    - darglint: 1.8.1
    - debugpy: 1.6.6
    - decorator: 5.1.1
    - defusedxml: 0.7.1
    - deprecated: 1.2.13
    - dictdiffer: 0.9.0
    - dill: 0.3.6
    - diskcache: 5.4.0
    - distlib: 0.3.6
    - distro: 1.8.0
    - dnspython: 2.3.0
    - docstring-parser: 0.15
    - docutils: 0.19
    - dpath: 2.1.4
    - dulwich: 0.21.2
    - dvc: 2.45.0
    - dvc-data: 0.40.1
    - dvc-http: 2.30.2
    - dvc-objects: 0.19.3
    - dvc-render: 0.1.2
    - dvc-s3: 2.21.0
    - dvc-studio-client: 0.4.0
    - dvc-task: 0.1.11
    - dvclive: 2.0.2
    - eradicate: 2.1.0
    - eventlet: 0.33.3
    - exceptiongroup: 1.1.0
    - executing: 1.2.0
    - fastjsonschema: 2.16.2
    - fiftyone: 0.18.0
    - fiftyone-brain: 0.9.2
    - fiftyone-db: 0.4.0
    - filelock: 3.9.0
    - flake8: 4.0.1
    - flake8-bandit: 3.0.0
    - flake8-broken-line: 0.5.0
    - flake8-bugbear: 22.12.6
    - flake8-commas: 2.1.0
    - flake8-comprehensions: 3.10.1
    - flake8-debugger: 4.1.2
    - flake8-docstrings: 1.7.0
    - flake8-eradicate: 1.4.0
    - flake8-isort: 4.2.0
    - flake8-polyfill: 1.0.2
    - flake8-quotes: 3.3.2
    - flake8-rst-docstrings: 0.2.7
    - flake8-string-format: 0.3.0
    - flatten-dict: 0.4.2
    - flufl.lock: 7.1.1
    - fonttools: 4.38.0
    - fqdn: 1.5.1
    - frozenlist: 1.3.3
    - fsspec: 2023.1.0
    - funcy: 1.18
    - furl: 2.1.3
    - future: 0.18.3
    - gitdb: 4.0.10
    - gitpython: 3.1.30
    - glob2: 0.7
    - grandalf: 0.8
    - graphql-core: 3.2.3
    - greenlet: 2.0.2
    - h11: 0.14.0
    - h2: 4.1.0
    - hpack: 4.0.0
    - httpcore: 0.16.3
    - httpx: 0.23.3
    - huggingface-hub: 0.12.0
    - hydra-core: 1.3.1
    - hypercorn: 0.14.3
    - hyperframe: 6.0.1
    - identify: 2.5.18
    - idna: 3.4
    - imageio: 2.25.1
    - importlib-resources: 5.10.2
    - iniconfig: 2.0.0
    - ipykernel: 6.21.2
    - ipython: 8.10.0
    - ipython-genutils: 0.2.0
    - ipywidgets: 8.0.4
    - isoduration: 20.11.0
    - isort: 5.12.0
    - iterative-telemetry: 0.0.7
    - jedi: 0.18.2
    - jinja2: 3.1.2
    - jmespath: 1.0.1
    - joblib: 1.2.0
    - json5: 0.9.11
    - jsonargparse: 4.19.0
    - jsonpointer: 2.3
    - jsonschema: 4.17.3
    - jupyter-client: 8.0.2
    - jupyter-contrib-core: 0.4.2
    - jupyter-contrib-nbextensions: 0.7.0
    - jupyter-core: 5.2.0
    - jupyter-events: 0.5.0
    - jupyter-highlight-selected-word: 0.2.0
    - jupyter-nbextensions-configurator: 0.6.1
    - jupyter-server: 2.2.1
    - jupyter-server-fileid: 0.6.0
    - jupyter-server-terminals: 0.4.4
    - jupyter-server-ydoc: 0.6.1
    - jupyter-ydoc: 0.2.2
    - jupyterlab: 3.6.1
    - jupyterlab-pygments: 0.2.2
    - jupyterlab-server: 2.19.0
    - jupyterlab-widgets: 3.0.5
    - kaleido: 0.2.1
    - kiwisolver: 1.4.4
    - kombu: 5.2.4
    - lightning-utilities: 0.6.0.post0
    - lxml: 4.9.2
    - markdown-it-py: 2.1.0
    - markupsafe: 2.1.2
    - matplotlib: 3.7.0
    - matplotlib-inline: 0.1.6
    - mccabe: 0.6.1
    - mdurl: 0.1.2
    - mistune: 2.0.5
    - mongoengine: 0.24.2
    - motor: 3.1.1
    - multidict: 6.0.4
    - nanotime: 0.5.2
    - nbclassic: 0.5.1
    - nbclient: 0.7.2
    - nbconvert: 7.2.9
    - nbformat: 5.7.3
    - ndjson: 0.3.1
    - nest-asyncio: 1.5.6
    - networkx: 3.0
    - nodeenv: 1.7.0
    - notebook: 6.5.2
    - notebook-shim: 0.2.2
    - numpy: 1.24.2
    - nvidia-cublas-cu11: 11.10.3.66
    - nvidia-cuda-nvrtc-cu11: 11.7.99
    - nvidia-cuda-runtime-cu11: 11.7.99
    - nvidia-cudnn-cu11: 8.5.0.96
    - omegaconf: 2.3.0
    - onnx: 1.13.0
    - opencv-python-headless: 4.7.0.68
    - orderedmultidict: 1.0.1
    - orjson: 3.8.6
    - packaging: 23.0
    - pandas: 1.5.3
    - pandocfilters: 1.5.0
    - parso: 0.8.3
    - pathlib2: 2.3.7.post1
    - pathspec: 0.11.0
    - patool: 1.12
    - pbr: 5.11.1
    - pep8-naming: 0.13.2
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.4.0
    - pip: 23.0
    - platformdirs: 3.0.0
    - plotly: 5.13.0
    - pluggy: 1.0.0
    - pprintpp: 0.4.0
    - pre-commit: 2.21.0
    - priority: 2.0.0
    - prometheus-client: 0.16.0
    - prompt-toolkit: 3.0.36
    - protobuf: 3.20.3
    - psutil: 5.9.4
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pycodestyle: 2.8.0
    - pycparser: 2.21
    - pydocstyle: 6.3.0
    - pydot: 1.4.2
    - pyflakes: 2.4.0
    - pygit2: 1.11.1
    - pygments: 2.14.0
    - pygtrie: 2.5.0
    - pyjwt: 2.4.0
    - pymongo: 4.3.3
    - pyparsing: 3.0.9
    - pyrsistent: 0.19.3
    - pytest: 7.2.1
    - python-dateutil: 2.8.2
    - python-json-logger: 2.0.6
    - pytorch-lightning: 1.9.1
    - pytz: 2022.7.1
    - pytz-deprecation-shim: 0.1.0.post0
    - pywavelets: 1.4.1
    - pyyaml: 6.0
    - pyzmq: 25.0.0
    - qudida: 0.0.4
    - requests: 2.28.2
    - restructuredtext-lint: 1.4.0
    - retrying: 1.3.4
    - rfc3339-validator: 0.1.4
    - rfc3986: 1.5.0
    - rfc3986-validator: 0.1.1
    - rich: 13.3.1
    - ruamel.yaml: 0.17.21
    - ruamel.yaml.clib: 0.2.7
    - s3fs: 2023.1.0
    - s3transfer: 0.6.0
    - scikit-image: 0.19.3
    - scikit-learn: 1.2.1
    - scipy: 1.10.0
    - scmrepo: 0.1.9
    - send2trash: 1.8.0
    - setuptools: 67.3.1
    - shortuuid: 1.0.11
    - shtab: 1.5.8
    - six: 1.16.0
    - smmap: 5.0.0
    - sniffio: 1.3.0
    - snowballstemmer: 2.2.0
    - sortedcontainers: 2.4.0
    - soupsieve: 2.4
    - sqltrie: 0.0.28
    - sse-starlette: 0.10.3
    - sseclient-py: 1.7.2
    - stack-data: 0.6.2
    - starlette: 0.20.4
    - stevedore: 5.0.0
    - strawberry-graphql: 0.138.1
    - tabulate: 0.9.0
    - tenacity: 8.2.1
    - tensorboardx: 2.6
    - terminado: 0.17.1
    - threadpoolctl: 3.1.0
    - tifffile: 2023.2.3
    - timm: 0.6.12
    - tinycss2: 1.2.1
    - toml: 0.10.2
    - tomli: 2.0.1
    - tomlkit: 0.11.6
    - torch: 1.13.1
    - torchmetrics: 0.11.1
    - torchvision: 0.14.1
    - tornado: 6.2
    - tqdm: 4.64.1
    - traitlets: 5.9.0
    - typeshed-client: 2.2.0
    - typing-extensions: 4.5.0
    - tzdata: 2022.7
    - tzlocal: 4.2
    - universal-analytics-python3: 1.1.1
    - uri-template: 1.2.0
    - urllib3: 1.26.14
    - vine: 5.0.0
    - virtualenv: 20.19.0
    - voluptuous: 0.13.1
    - voxel51-eta: 0.8.3
    - wcwidth: 0.2.6
    - webcolors: 1.12
    - webencodings: 0.5.1
    - websocket-client: 1.5.1
    - wemake-python-styleguide: 0.17.0
    - wheel: 0.38.4
    - widgetsnbextension: 4.0.5
    - wrapt: 1.14.1
    - wsproto: 1.2.0
    - xmltodict: 0.13.0
    - y-py: 0.5.5
    - yarl: 1.8.2
    - ypy-websocket: 0.8.2
    - zc.lockfile: 2.0
* System:
    - OS: Linux
    - architecture:
        - 64bit
        -
    - processor:
    - python: 3.10.10
    - version: #152-Ubuntu SMP Wed Nov 23 20:19:22 UTC 2022
```
More info
I think the problem is specific to how and when optimizers and schedulers are instantiated, because when I ran the same code for the batch size only, it worked as expected:
it used the found batch size in training.
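The instantiation-order point can be illustrated in plain PyTorch, independent of Lightning: an optimizer copies its `lr` into `param_groups` when it is constructed, so changing a value elsewhere afterwards does not reach it. This is a minimal sketch; the tiny model, the `model.lr` attribute, and the "found" value are hypothetical stand-ins, not what Lightning does internally.

```python
import torch

# Hypothetical stand-ins: a tiny model, the lr from the config file,
# and a value an LR finder might suggest.
model = torch.nn.Linear(4, 1)
model.lr = 1e-3  # plain attribute, similar to one an LR finder may update
optimizer = torch.optim.Adam(model.parameters(), lr=model.lr)

found_lr = 3e-2
model.lr = found_lr  # updating the attribute after construction...

# ...does not touch the optimizer: its lr was copied into param_groups.
stale_lr = optimizer.param_groups[0]["lr"]
print(stale_lr)  # 0.001

# Either rebuild the optimizer (what re-running configure_optimizers does)
# or write the new value into every param group in place.
for group in optimizer.param_groups:
    group["lr"] = found_lr
print(optimizer.param_groups[0]["lr"])  # 0.03
```

This is why the order matters: the found LR only takes effect if the optimizer is built (or its param groups are edited) after the search.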
For now, as I understand it, the way to use `LearningRateFinder` is to define `configure_optimizers()` manually in the `LightningModule`. But then I can't change the optimizer from the YAML config file.
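One possible way to keep the optimizer configurable while writing `configure_optimizers()` by hand is to pass the optimizer section in as plain data and instantiate it yourself. The sketch below assumes the `class_path`/`init_args` layout that LightningCLI uses for classes in YAML; `build_optimizer` and `lr_override` are hypothetical names, not a Lightning API. Inside a module, `configure_optimizers()` could call it with `self.parameters()` and an `lr` value the module holds (assuming the finder is set up to update that value).

```python
import importlib

import torch


def build_optimizer(params, optimizer_cfg, lr_override=None):
    """Instantiate an optimizer from a class_path/init_args dict (hypothetical helper).

    optimizer_cfg mirrors the YAML layout, e.g.
        {"class_path": "torch.optim.Adam", "init_args": {"lr": 0.001}}
    If lr_override is given (e.g. an LR finder suggestion), it replaces the
    configured lr; every other init arg is kept as configured.
    """
    module_name, class_name = optimizer_cfg["class_path"].rsplit(".", 1)
    optimizer_cls = getattr(importlib.import_module(module_name), class_name)
    init_args = dict(optimizer_cfg.get("init_args", {}))  # copy, don't mutate cfg
    if lr_override is not None:
        init_args["lr"] = lr_override
    return optimizer_cls(params, **init_args)


# Usage: the optimizer class still comes from config; only lr is overridden.
cfg = {"class_path": "torch.optim.AdamW", "init_args": {"lr": 1e-3, "weight_decay": 0.01}}
model = torch.nn.Linear(4, 1)
opt = build_optimizer(model.parameters(), cfg, lr_override=3e-2)
print(type(opt).__name__, opt.param_groups[0]["lr"], opt.param_groups[0]["weight_decay"])
# AdamW 0.03 0.01
```

The design choice here is to override only `lr` and leave every other `init_args` entry untouched, so swapping the optimizer class in the config keeps working.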