Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

save_hyperparameter incorrectly infers parameters from superclass #19596

Open klieret opened 6 months ago

klieret commented 6 months ago

Bug description

Given a model with a submodel, where both call save_hyperparameters(), any hyperparameter of the submodel that shares its name with one of the main model's hyperparameters is overwritten with the main model's value.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin

class Submodel(HyperparametersMixin):
    def __init__(self, hparam: int):
        super().__init__()
        self.save_hyperparameters()  # expected to record hparam=3

class Model(HyperparametersMixin):
    def __init__(self, hparam: int):
        super().__init__()
        # The submodel's argument has the same name as Model's own argument.
        self.submodel = Submodel(hparam=3)
        self.save_hyperparameters()  # expected to record hparam=5

model = Model(hparam=5)
print(model.hparams)
print(model.submodel.hparams)

Expectation: model.hparams.hparam == 5, model.submodel.hparams.hparam == 3
Reality: model.hparams.hparam == 5, model.submodel.hparams.hparam == 5
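To see how a name clash like this can arise, here is a simplified, hypothetical model of a collector that walks outward through every enclosing `__init__` frame and merges their arguments, with outer frames applied last. This is not Lightning's implementation; `NaiveHparamsMixin` and `collect_init_args` here are illustrative names, and the sketch only assumes, as the reported output suggests, that arguments of the enclosing `Model.__init__` end up in the submodel's hyperparameters and win over its own.

```python
import inspect
from types import FrameType
from typing import Optional


def collect_init_args(frame: Optional[FrameType]) -> dict:
    """Hypothetical collector: merge the arguments of all enclosing __init__ frames.

    Frames are visited innermost first; because outer frames are merged last,
    an outer argument overwrites an inner argument with the same name.
    """
    collected = []
    while frame is not None and frame.f_code.co_name == "__init__":
        arg_names, _, _, local_vars = inspect.getargvalues(frame)
        collected.append({name: local_vars[name] for name in arg_names if name != "self"})
        frame = frame.f_back  # step out to the caller's frame
    merged = {}
    for frame_args in collected:  # innermost first, outermost last
        merged.update(frame_args)
    return merged


class NaiveHparamsMixin:
    def save_hyperparameters(self) -> None:
        caller = inspect.currentframe().f_back  # the __init__ that called us
        self.hparams = collect_init_args(caller)


class Submodel(NaiveHparamsMixin):
    def __init__(self, hparam: int):
        super().__init__()
        self.save_hyperparameters()


class Model(NaiveHparamsMixin):
    def __init__(self, hparam: int):
        super().__init__()
        self.submodel = Submodel(hparam=3)
        self.save_hyperparameters()


model = Model(hparam=5)
print(model.hparams)           # {'hparam': 5}
print(model.submodel.hparams)  # {'hparam': 5}: Model's value overwrote the submodel's 3
```

While `Submodel(hparam=3)` runs, `Model.__init__` is still on the call stack, so a collector that keeps walking outward sees both `hparam=3` and `hparam=5` and keeps the outer one.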

Error messages and logs

No response

Environment

Current environment
* CUDA:
  - GPU: None
  - available: False
  - version: None
* Lightning:
  - lightning: 2.2.1
  - lightning-utilities: 0.8.0
  - pytorch-lightning: 2.2.1
  - torch: 2.0.1
  - torch-cluster: 1.6.1
  - torch-geometric: 2.3.1
  - torchmetrics: 1.0.0
* Packages:
  - accessible-pygments: 0.0.4
  - aiohttp: 3.8.4
  - aiosignal: 1.3.1
  - alabaster: 0.7.13
  - alembic: 1.11.1
  - antlr4-python3-runtime: 4.9.3
  - anyio: 3.7.0
  - appdirs: 1.4.4
  - appnope: 0.1.3
  - argon2-cffi: 21.3.0
  - argon2-cffi-bindings: 21.2.0
  - arrow: 1.2.3
  - astroid: 3.0.0
  - asttokens: 2.2.1
  - async-lru: 2.0.2
  - async-timeout: 4.0.2
  - attrs: 23.1.0
  - awkward: 2.2.4
  - awkward-cpp: 17
  - babel: 2.12.1
  - backcall: 0.2.0
  - backports.functools-lru-cache: 1.6.4
  - beautifulsoup4: 4.12.2
  - bleach: 6.0.0
  - certifi: 2023.5.7
  - cffi: 1.15.1
  - cfgv: 3.3.1
  - charset-normalizer: 3.1.0
  - click: 8.1.3
  - cmaes: 0.9.1
  - codespell: 2.2.6
  - colorama: 0.4.6
  - colorlog: 6.7.0
  - comm: 0.1.3
  - commonmark: 0.9.1
  - contourpy: 1.0.7
  - coolname: 2.2.0
  - coverage: 7.2.7
  - cycler: 0.11.0
  - debugpy: 1.6.7
  - decorator: 5.1.1
  - defusedxml: 0.7.1
  - dill: 0.3.7
  - diskcache: 5.6.3
  - distlib: 0.3.6
  - docker-pycreds: 0.4.0
  - docstring-parser: 0.15
  - docutils: 0.19
  - entrypoints: 0.4
  - exceptiongroup: 1.1.1
  - executing: 1.2.0
  - fastjsonschema: 2.17.1
  - filelock: 3.12.0
  - flit-core: 3.9.0
  - fonttools: 4.39.4
  - fqdn: 1.5.1
  - frozenlist: 1.3.3
  - fsspec: 2023.6.0
  - gitdb: 4.0.10
  - gitpython: 3.1.31
  - gmpy2: 2.1.2
  - gnn-tracking: 0.0.1
  - gnntrack: 0.0.1
  - greenlet: 2.0.2
  - hpo2: 0.1.0
  - hsfparana: 0.1.0
  - hydra-core: 1.3.2
  - identify: 2.5.24
  - idna: 3.4
  - imagesize: 1.4.1
  - importlib-metadata: 6.6.0
  - importlib-resources: 5.12.0
  - iniconfig: 2.0.0
  - ipykernel: 6.23.1
  - ipython: 8.14.0
  - ipython-genutils: 0.2.0
  - ipywidgets: 8.0.6
  - isoduration: 20.11.0
  - isort: 5.12.0
  - jedi: 0.18.2
  - jinja2: 3.1.2
  - joblib: 1.2.0
  - json5: 0.9.14
  - jsonargparse: 4.21.2
  - jsonpointer: 2.4
  - jsonschema: 4.17.3
  - jupyter: 1.0.0
  - jupyter-client: 8.2.0
  - jupyter-console: 6.6.3
  - jupyter-core: 5.3.0
  - jupyter-events: 0.6.3
  - jupyter-lsp: 2.2.0
  - jupyter-server: 2.6.0
  - jupyter-server-terminals: 0.4.4
  - jupyterlab: 4.0.2
  - jupyterlab-pygments: 0.2.2
  - jupyterlab-server: 2.23.0
  - jupyterlab-widgets: 3.0.7
  - kiwisolver: 1.4.4
  - lazy-object-proxy: 1.9.0
  - lightning: 2.2.1
  - lightning-utilities: 0.8.0
  - llvmlite: 0.40.1
  - mako: 1.2.4
  - markdown-it-py: 3.0.0
  - markupsafe: 2.1.3
  - matplotlib: 3.7.1
  - matplotlib-inline: 0.1.6
  - mccabe: 0.7.0
  - mdmm: 0.1.3
  - mdurl: 0.1.2
  - mistune: 2.0.5
  - mplhep: 0.3.28
  - mplhep-data: 0.0.3
  - mpmath: 1.3.0
  - msgpack: 1.0.5
  - multidict: 6.0.4
  - nbclassic: 1.0.0
  - nbclient: 0.8.0
  - nbconvert: 7.4.0
  - nbformat: 5.9.0
  - nest-asyncio: 1.5.6
  - networkx: 3.1
  - nodeenv: 1.8.0
  - notebook: 6.5.4
  - notebook-shim: 0.2.3
  - numba: 0.57.1
  - numpy: 1.24.4
  - object-condensation: 0.1.dev20+gf5708c7
  - omegaconf: 2.3.0
  - optuna: 3.2.0
  - overrides: 7.3.1
  - packaging: 23.1
  - pandas: 2.0.2
  - pandocfilters: 1.5.0
  - parso: 0.8.3
  - pathtools: 0.1.2
  - pexpect: 4.8.0
  - pickleshare: 0.7.5
  - pillow: 9.5.0
  - pip: 23.1.2
  - pkgutil-resolve-name: 1.3.10
  - platformdirs: 3.5.1
  - pluggy: 1.0.0
  - pooch: 1.7.0
  - pre-commit: 3.3.2
  - prometheus-client: 0.17.0
  - prompt-toolkit: 3.0.38
  - protobuf: 3.20.3
  - psutil: 5.9.5
  - ptyprocess: 0.7.0
  - pure-eval: 0.2.2
  - pyarrow: 12.0.1
  - pycparser: 2.21
  - pydata-sphinx-theme: 0.13.3
  - pygments: 2.15.1
  - pylint: 3.0.0
  - pyobjc-core: 9.2
  - pyobjc-framework-cocoa: 9.2
  - pyparsing: 3.0.9
  - pyrsistent: 0.19.3
  - pysocks: 1.7.1
  - pytest: 7.4.0
  - pytest-cov: 4.1.0
  - pytest-cover: 3.0.0
  - pytest-coverage: 0.0
  - python-dateutil: 2.8.2
  - python-frontmatter: 1.0.0
  - python-json-logger: 2.0.7
  - pytorch-lightning: 2.2.1
  - pytz: 2023.3
  - pyyaml: 6.0
  - pyzmq: 25.1.0
  - ray: 2.5.1
  - recommonmark: 0.7.1
  - requests: 2.31.0
  - rfc3339-validator: 0.1.4
  - rfc3986-validator: 0.1.1
  - rich: 13.4.2
  - ruff: 0.0.276
  - scienceplots: 2.1.1
  - scikit-learn: 1.2.2
  - scipy: 1.10.1
  - send2trash: 1.8.2
  - sentry-sdk: 1.21.1
  - setproctitle: 1.3.2
  - setuptools: 67.7.2
  - six: 1.16.0
  - smmap: 3.0.5
  - sniffio: 1.3.0
  - snowballstemmer: 2.2.0
  - soupsieve: 2.3.2.post1
  - sphinx: 6.2.1
  - sphinx-autoapi: 2.1.0
  - sphinx-book-theme: 1.0.1
  - sphinxcontrib-applehelp: 1.0.4
  - sphinxcontrib-devhelp: 1.0.2
  - sphinxcontrib-htmlhelp: 2.0.1
  - sphinxcontrib-jsmath: 1.0.1
  - sphinxcontrib-qthelp: 1.0.3
  - sphinxcontrib-serializinghtml: 1.1.5
  - sqlalchemy: 2.0.15
  - stack-data: 0.6.2
  - sympy: 1.12
  - tabulate: 0.9.0
  - tensorboardx: 2.6
  - terminado: 0.17.1
  - threadpoolctl: 3.1.0
  - tinycss2: 1.2.1
  - tomlkit: 0.12.1
  - torch: 2.0.1
  - torch-cluster: 1.6.1
  - torch-geometric: 2.3.1
  - torchmetrics: 1.0.0
  - tornado: 6.3.2
  - tqdm: 4.65.0
  - trackml: 3
  - traitlets: 5.9.0
  - types-markupsafe: 1.1.10
  - typeshed-client: 2.3.0
  - typing-extensions: 4.6.3
  - typing-utils: 0.1.0
  - tzdata: 2023.3
  - uhi: 0.3.3
  - unidecode: 1.3.6
  - uproot: 5.0.9
  - uri-template: 1.3.0
  - urllib3: 2.0.3
  - virtualenv: 20.23.0
  - wandb: 0.16.3
  - wandb-osh: 1.0.4
  - wcwidth: 0.2.6
  - webcolors: 1.13
  - webencodings: 0.5.1
  - websocket-client: 1.5.2
  - wheel: 0.40.0
  - widgetsnbextension: 4.0.7
  - wrapt: 1.15.0
  - yarl: 1.9.2
  - zipp: 3.15.0
* System:
  - OS: Darwin
  - architecture: 64bit
  - processor: arm
  - python: 3.11.3
  - release: 23.2.0
  - version: Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000

More info

No response

awaelchli commented 6 months ago

Hey @klieret

The save_hyperparameters() feature is often misunderstood in the context of inheritance. It is intended to capture the input arguments, and only the input arguments, of the model you instantiate. That means that in your case

model = Model(hparam=5)

the only thing that self.save_hyperparameters() should save is "hparam=5", because that is the only argument you need to re-instantiate your model.
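A minimal sketch of that intended behavior for the submodule case, assuming a hypothetical toy mixin (`OwnArgsHparamsMixin` is an illustrative name, not a Lightning API) that records only the arguments of the single `__init__` frame that calls `save_hyperparameters()` and never looks at enclosing frames:

```python
import inspect


class OwnArgsHparamsMixin:
    """Hypothetical mixin: record only the arguments of the calling __init__."""

    def save_hyperparameters(self) -> None:
        caller = inspect.currentframe().f_back  # the __init__ that called us
        arg_names, _, _, local_vars = inspect.getargvalues(caller)
        # Only this one frame is inspected; enclosing __init__ frames are ignored.
        self.hparams = {name: local_vars[name] for name in arg_names if name != "self"}


class Submodel(OwnArgsHparamsMixin):
    def __init__(self, hparam: int):
        super().__init__()
        self.save_hyperparameters()


class Model(OwnArgsHparamsMixin):
    def __init__(self, hparam: int):
        super().__init__()
        self.submodel = Submodel(hparam=3)
        self.save_hyperparameters()


model = Model(hparam=5)
print(model.hparams)           # {'hparam': 5}
print(model.submodel.hparams)  # {'hparam': 3}

# The saved arguments are exactly what is needed to re-instantiate the model.
clone = Model(**model.hparams)
```

This toy deliberately ignores `*args`/`**kwargs` and does not cover the inheritance case, where `save_hyperparameters()` is called from a parent class's `__init__` via `super()`; there the immediate caller is the parent, not the class the user instantiated.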

The fact that save_hyperparameters() also captures the parameters of other modules (submodules or super() calls) leads to a lot of confusion and ill-defined behavior. It is the result of a bad design that is now so old that it is impossible to fix without breaking changes.

TL;DR: The good news is that we know how to fix the situation. The bad news is that it will be a while until this can be fixed, because we first need to go through a deprecation of some problematic features in save_hyperparameters(), like #10375. After removing this "bad code", it will be possible to actually save the right things.

You are not the first to report this, so I think we need to start moving in that direction.