Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

KeyError in save_hyperparameters while using in a subclass #18405

Open vpozdnyakov opened 1 year ago

vpozdnyakov commented 1 year ago

Bug description

A KeyError occurs when I call save_hyperparameters() in a LightningModule that is instantiated from a method of a subclass (in the example below, Model.fit, which overrides Base.fit).

What version are you seeing the problem on?

v2.0

How to reproduce the bug

from pytorch_lightning import LightningModule

class Base:
    def fit(self):
        pass

class LightningModel(LightningModule):
    def __init__(self, hidden_dim):
        super().__init__()
        self.save_hyperparameters()

class Model(Base):
    def fit(self):
        super().fit()
        self.model = LightningModel(hidden_dim=2)        

model = Model()
model.fit()

Error messages and logs

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-3-d10b384b9dc7> in <cell line: 18>()
     16 
     17 model = Model()
---> 18 model.fit()

<ipython-input-3-d10b384b9dc7> in fit(self)
     13     def fit(self):
     14         super().fit()
---> 15         self.model = LightningModel(hidden_dim=2)
     16 
     17 model = Model()

<ipython-input-3-d10b384b9dc7> in __init__(self, hidden_dim)
      8     def __init__(self, hidden_dim):
      9         super().__init__()
---> 10         self.save_hyperparameters()
     11 
     12 class Model(Base):

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/mixins/hparams_mixin.py in save_hyperparameters(self, ignore, frame, logger, *args)
    109             if current_frame:
    110                 frame = current_frame.f_back
--> 111         save_hyperparameters(self, *args, ignore=ignore, frame=frame)
    112 
    113     def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/parsing.py in save_hyperparameters(obj, ignore, frame, *args)
    162         from pytorch_lightning.core.mixins import HyperparametersMixin
    163 
--> 164         for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
    165             init_args.update(local_args)
    166 

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/parsing.py in collect_init_args(frame, path_args, inside, classes)
    132         # recursive update
    133         path_args.append(local_args)
--> 134         return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
    135     if not inside:
    136         return collect_init_args(frame.f_back, path_args, inside=False, classes=classes)

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/parsing.py in collect_init_args(frame, path_args, inside, classes)
    128         return path_args
    129 
--> 130     local_self, local_args = _get_init_args(frame)
    131     if "__class__" in local_vars and (not classes or isinstance(local_self, classes)):
    132         # recursive update

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/parsing.py in _get_init_args(frame)
     94     exclude_argnames = (*filtered_vars, "__class__", "frame", "frame_args")
     95     # only collect variables that appear in the signature
---> 96     local_args = {k: local_vars[k] for k in init_parameters}
     97     # kwargs_var might be None => raised an error by mypy
     98     if kwargs_var:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/parsing.py in <dictcomp>(.0)
     94     exclude_argnames = (*filtered_vars, "__class__", "frame", "frame_args")
     95     # only collect variables that appear in the signature
---> 96     local_args = {k: local_vars[k] for k in init_parameters}
     97     # kwargs_var might be None => raised an error by mypy
     98     if kwargs_var:

KeyError: 'args'

Environment

Current environment
* CUDA:
  - GPU: None
  - available: False
  - version: 11.8
* Lightning:
  - lightning: 2.0.7
  - lightning-cloud: 0.5.37
  - lightning-utilities: 0.9.0
  - pytorch-lightning: 2.0.7
  - torch: 2.0.1+cu118
  - torchaudio: 2.0.2+cu118
  - torchdata: 0.6.1
  - torchmetrics: 1.1.0
  - torchsummary: 1.5.1
  - torchtext: 0.15.2
  - torchvision: 0.15.2+cu118
* Packages: - absl-py: 1.4.0 - aiohttp: 3.8.5 - aiosignal: 1.3.1 - alabaster: 0.7.13 - albumentations: 1.3.1 - altair: 4.2.2 - annotated-types: 0.5.0 - anyio: 3.7.1 - appdirs: 1.4.4 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - array-record: 0.4.1 - arrow: 1.2.3 - arviz: 0.15.1 - astropy: 5.3.2 - astunparse: 1.6.3 - async-timeout: 4.0.3 - attrs: 23.1.0 - audioread: 3.0.0 - autograd: 1.6.2 - babel: 2.12.1 - backcall: 0.2.0 - backoff: 2.2.1 - beautifulsoup4: 4.11.2 - bleach: 6.0.0 - blessed: 1.20.0 - blinker: 1.4 - blis: 0.7.10 - blosc2: 2.0.0 - bokeh: 3.2.2 - branca: 0.6.0 - build: 0.10.0 - cachecontrol: 0.13.1 - cachetools: 5.3.1 - catalogue: 2.0.9 - certifi: 2023.7.22 - cffi: 1.15.1 - chardet: 5.2.0 - charset-normalizer: 3.2.0 - chex: 0.1.7 - click: 8.1.7 - click-plugins: 1.1.1 - cligj: 0.7.2 - cloudpickle: 2.2.1 - cmake: 3.27.2 - cmdstanpy: 1.1.0 - colorcet: 3.0.1 - colorlover: 0.3.0 - community: 1.0.0b1 - confection: 0.1.1 - cons: 0.4.6 - contextlib2: 21.6.0 - contourpy: 1.1.0 - convertdate: 2.4.0 - croniter: 1.4.1 - cryptography: 41.0.3 - cufflinks: 0.17.3 - cvxopt: 1.3.2 - cvxpy: 1.3.2 - cycler: 0.11.0 - cymem: 2.0.7 - cython: 0.29.36 - dask: 2023.8.1 - datascience: 0.17.6 - dateutils: 0.6.12 - db-dtypes: 1.1.1 - dbus-python: 1.2.18 - debugpy: 1.6.6 - decorator: 4.4.2 - deepdiff: 6.3.1 - defusedxml: 0.7.1 - distributed: 2023.8.1 - distro: 1.7.0 - dlib: 19.24.2 - dm-tree: 0.1.8 - docutils: 0.18.1 - dopamine-rl: 4.0.6 - duckdb: 0.8.1 - earthengine-api: 0.1.364 - easydict: 1.10 - ecos: 2.0.12 - editdistance: 0.6.2 - en-core-web-sm: 3.6.0 - entrypoints: 0.4 - ephem: 4.1.4 - et-xmlfile: 1.1.0 - etils: 1.4.1 - etuples: 0.3.9 - exceptiongroup: 1.1.3 - fastai: 2.7.12 - fastapi: 0.103.0 - fastcore: 1.5.29 - fastdownload: 0.0.7 - fastjsonschema: 2.18.0 - fastprogress: 1.0.3 - fastrlock: 0.8.1 - filelock: 3.12.2 - fiona: 1.9.4.post1 - firebase-admin: 5.3.0 - flask: 2.2.5 - flatbuffers: 23.5.26 - flax: 0.7.2 - folium: 0.14.0 - fonttools: 4.42.1 - frozendict: 2.3.8 - frozenlist: 1.4.0 - fsspec: 2023.6.0 - future: 0.18.3 - gast: 0.4.0 - gcsfs: 2023.6.0 - gdal: 3.4.3 - gdown: 4.6.6 - gensim: 4.3.1 - geographiclib: 2.0 - geopandas: 0.13.2 - geopy: 2.3.0 - gin-config: 0.5.0 - glob2: 0.7 - google: 2.0.3 - google-api-core: 2.11.1 - google-api-python-client: 2.84.0 - google-auth: 2.17.3 - google-auth-httplib2: 0.1.0 - google-auth-oauthlib: 1.0.0 - google-cloud-bigquery: 3.10.0 - google-cloud-bigquery-connection: 1.12.1 - google-cloud-bigquery-storage: 2.22.0 - google-cloud-core: 2.3.3 - google-cloud-datastore: 2.15.2 - google-cloud-firestore: 2.11.1 - google-cloud-functions: 1.13.2 - google-cloud-language: 2.9.1 - google-cloud-storage: 2.8.0 - google-cloud-translate: 3.11.3 - google-colab: 1.0.0 - google-crc32c: 1.5.0 - google-pasta: 0.2.0 - google-resumable-media: 2.5.0 - googleapis-common-protos: 1.60.0 - googledrivedownloader: 0.4 - graphviz: 0.20.1 - greenlet: 2.0.2 - grpc-google-iam-v1: 0.12.6 - grpcio: 1.57.0 - grpcio-status: 1.48.2 - gspread: 3.4.2 - gspread-dataframe: 3.3.1 - gym: 0.25.2 - gym-notices: 0.0.8 - h11: 0.14.0 - h5netcdf: 1.2.0 - h5py: 3.9.0 - holidays: 0.31 - holoviews: 1.17.1 - html5lib: 1.1 - httpimport: 1.3.1 - httplib2: 0.22.0 - humanize: 4.7.0 - hyperopt: 0.2.7 - idna: 3.4 - imageio: 2.31.1 - imageio-ffmpeg: 0.4.8 - imagesize: 1.4.1 - imbalanced-learn: 0.10.1 - imgaug: 0.4.0 - importlib-metadata: 6.8.0 - importlib-resources: 6.0.1 - imutils: 0.5.4 - inflect: 7.0.0 - iniconfig: 2.0.0 - inquirer: 3.1.3 - intel-openmp: 2023.2.0 - ipykernel: 5.5.6 - ipython: 7.34.0 - ipython-genutils: 0.2.0 - ipython-sql: 0.5.0 - ipywidgets: 7.7.1 - itsdangerous: 2.1.2 - jax: 0.4.14 - jaxlib: 0.4.14+cuda11.cudnn86 - jeepney: 0.7.1 - jieba: 0.42.1 - jinja2: 3.1.2 - joblib: 1.3.2 - jsonpickle: 3.0.2 - jsonschema: 4.19.0 - jsonschema-specifications: 2023.7.1 - jupyter-client: 6.1.12 - jupyter-console: 6.1.0 - jupyter-core: 5.3.1 - jupyter-server: 1.24.0 - jupyterlab-pygments: 0.2.2 - jupyterlab-widgets: 3.0.8 - kaggle: 1.5.16 - keras: 2.12.0 - keyring: 23.5.0 - kiwisolver: 1.4.4 - langcodes: 3.3.0 - launchpadlib: 1.10.16 - lazr.restfulclient: 0.14.4 - lazr.uri: 1.0.6 - lazy-loader: 0.3 - libclang: 16.0.6 - librosa: 0.10.1 - lightgbm: 4.0.0 - lightning: 2.0.7 - lightning-cloud: 0.5.37 - lightning-utilities: 0.9.0 - linkify-it-py: 2.0.2 - lit: 16.0.6 - llvmlite: 0.39.1 - locket: 1.0.0 - logical-unification: 0.4.6 - lunarcalendar: 0.0.9 - lxml: 4.9.3 - markdown: 3.4.4 - markdown-it-py: 3.0.0 - markupsafe: 2.1.3 - matplotlib: 3.7.1 - matplotlib-inline: 0.1.6 - matplotlib-venn: 0.11.9 - mdit-py-plugins: 0.4.0 - mdurl: 0.1.2 - minikanren: 1.0.3 - missingno: 0.5.2 - mistune: 0.8.4 - mizani: 0.9.2 - mkl: 2023.2.0 - ml-dtypes: 0.2.0 - mlxtend: 0.22.0 - more-itertools: 10.1.0 - moviepy: 1.0.3 - mpmath: 1.3.0 - msgpack: 1.0.5 - multidict: 6.0.4 - multipledispatch: 1.0.0 - multitasking: 0.0.11 - murmurhash: 1.0.9 - music21: 9.1.0 - natsort: 8.4.0 - nbclassic: 1.0.0 - nbclient: 0.8.0 - nbconvert: 6.5.4 - nbformat: 5.9.2 - nest-asyncio: 1.5.7 - networkx: 3.1 - nibabel: 4.0.2 - nltk: 3.8.1 - notebook: 6.5.5 - notebook-shim: 0.2.3 - numba: 0.56.4 - numexpr: 2.8.5 - numpy: 1.23.5 - oauth2client: 4.1.3 - oauthlib: 3.2.2 - opencv-contrib-python: 4.8.0.76 - opencv-python: 4.8.0.76 - opencv-python-headless: 4.8.0.76 - openpyxl: 3.1.2 - opt-einsum: 3.3.0 - optax: 0.1.7 - orbax-checkpoint: 0.3.5 - ordered-set: 4.1.0 - osqp: 0.6.2.post8 - packaging: 23.1 - pandas: 1.5.3 - pandas-datareader: 0.10.0 - pandas-gbq: 0.17.9 - pandocfilters: 1.5.0 - panel: 1.2.1 - param: 1.13.0 - parso: 0.8.3 - partd: 1.4.0 - pathlib: 1.0.1 - pathy: 0.10.2 - patsy: 0.5.3 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.4.0 - pip: 23.1.2 - pip-tools: 6.13.0 - platformdirs: 3.10.0 - plotly: 5.15.0 - plotnine: 0.12.2 - pluggy: 1.2.0 - polars: 0.17.3 - pooch: 1.7.0 - portpicker: 1.5.2 - prefetch-generator: 1.0.3 - preshed: 3.0.8 - prettytable: 3.8.0 - proglog: 0.1.10 - progressbar2: 4.2.0 - prometheus-client: 0.17.1 - promise: 2.3 - prompt-toolkit: 3.0.39 - prophet: 1.1.4 - proto-plus: 1.22.3 - protobuf: 3.20.3 - psutil: 5.9.5 - psycopg2: 2.9.7 - ptyprocess: 0.7.0 - py-cpuinfo: 9.0.0 - py4j: 0.10.9.7 - pyarrow: 9.0.0 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pycocotools: 2.0.7 - pycparser: 2.21 - pyct: 0.5.0 - pydantic: 2.1.1 - pydantic-core: 2.4.0 - pydata-google-auth: 1.8.2 - pydot: 1.4.2 - pydot-ng: 2.0.0 - pydotplus: 2.0.2 - pydrive: 1.3.1 - pydrive2: 1.6.3 - pyerfa: 2.0.0.3 - pygame: 2.5.1 - pygments: 2.16.1 - pygobject: 3.42.1 - pyjwt: 2.3.0 - pymc: 5.7.2 - pymeeus: 0.5.12 - pymystem3: 0.2.0 - pyopengl: 3.1.7 - pyopenssl: 23.2.0 - pyparsing: 3.1.1 - pyproj: 3.6.0 - pyproject-hooks: 1.0.0 - pysocks: 1.7.1 - pytensor: 2.14.2 - pytest: 7.4.0 - python-apt: 0.0.0 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-louvain: 0.16 - python-multipart: 0.0.6 - python-slugify: 8.0.1 - python-utils: 3.7.0 - pytorch-lightning: 2.0.7 - pytz: 2023.3 - pyviz-comms: 3.0.0 - pywavelets: 1.4.1 - pyyaml: 6.0.1 - pyzmq: 23.2.1 - qdldl: 0.1.7.post0 - qudida: 0.0.4 - readchar: 4.0.5 - referencing: 0.30.2 - regex: 2023.6.3 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - requirements-parser: 0.5.0 - rich: 13.5.2 - rpds-py: 0.9.2 - rpy2: 3.4.2 - rsa: 4.9 - scikit-image: 0.19.3 - scikit-learn: 1.2.2 - scipy: 1.10.1 - scs: 3.2.3 - seaborn: 0.12.2 - secretstorage: 3.3.1 - send2trash: 1.8.2 - setuptools: 67.7.2 - shapely: 2.0.1 - six: 1.16.0 - sklearn-pandas: 2.2.0 - smart-open: 6.3.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - sortedcontainers: 2.4.0 - soundfile: 0.12.1 - soupsieve: 2.4.1 - soxr: 0.3.6 - spacy: 3.6.1 - spacy-legacy: 3.0.12 - spacy-loggers: 1.0.4 - sphinx: 5.0.2 - sphinxcontrib-applehelp: 1.0.7 - sphinxcontrib-devhelp: 1.0.5 - sphinxcontrib-htmlhelp: 2.0.4 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 1.0.6 - sphinxcontrib-serializinghtml: 1.1.9 - sqlalchemy: 2.0.20 - sqlparse: 0.4.4 - srsly: 2.4.7 - starlette: 0.27.0 - starsessions: 1.3.0 - statsmodels: 0.14.0 - sympy: 1.12 - tables: 3.8.0 - tabulate: 0.9.0 - tbb: 2021.10.0 - tblib: 2.0.0 - tenacity: 8.2.3 - tensorboard: 2.12.3 - tensorboard-data-server: 0.7.1 - tensorflow: 2.12.0 - tensorflow-datasets: 4.9.2 - tensorflow-estimator: 2.12.0 - tensorflow-gcs-config: 2.12.0 - tensorflow-hub: 0.14.0 - tensorflow-io-gcs-filesystem: 0.33.0 - tensorflow-metadata: 1.14.0 - tensorflow-probability: 0.20.1 - tensorstore: 0.1.41 - termcolor: 2.3.0 - terminado: 0.17.1 - text-unidecode: 1.3 - textblob: 0.17.1 - tf-slim: 1.1.0 - thinc: 8.1.12 - threadpoolctl: 3.2.0 - tifffile: 2023.8.12 - tinycss2: 1.2.1 - toml: 0.10.2 - tomli: 2.0.1 - toolz: 0.12.0 - torch: 2.0.1+cu118 - torchaudio: 2.0.2+cu118 - torchdata: 0.6.1 - torchmetrics: 1.1.0 - torchsummary: 1.5.1 - torchtext: 0.15.2 - torchvision: 0.15.2+cu118 - tornado: 6.3.2 - tqdm: 4.66.1 - traitlets: 5.7.1 - triton: 2.0.0 - tweepy: 4.13.0 - typer: 0.9.0 - types-setuptools: 68.1.0.0 - typing-extensions: 4.7.1 - tzlocal: 5.0.1 - uc-micro-py: 1.0.2 - uritemplate: 4.1.1 - urllib3: 2.0.4 - uvicorn: 0.23.2 - vega-datasets: 0.9.0 - wadllib: 1.3.6 - wasabi: 1.1.2 - wcwidth: 0.2.6 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.6.2 - websockets: 11.0.3 - werkzeug: 2.3.7 - wheel: 0.41.2 - widgetsnbextension: 3.6.5 - wordcloud: 1.9.2 - wrapt: 1.14.1 - xarray: 2023.7.0 - xarray-einstats: 0.6.0 - xgboost: 1.7.6 - xlrd: 2.0.1 - xyzservices: 2023.7.0 - yarl: 1.9.2 - yellowbrick: 1.5 - yfinance: 0.2.28 - zict: 3.0.0 - zipp: 3.16.2
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.10.12
  - release: 5.15.109+
  - version: #1 SMP Fri Jun 9 10:57:30 UTC 2023

More info

It works if I rename the subclass method Model.fit to __init__.
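A Lightning-free sketch of why the rename matters, assuming the lookup works like the _get_init_args lines shown in the traceback: take __class__ from the inspected frame's locals, then read every parameter of cls.__init__ from those same locals. lookup_init_args, WithFit and WithInit are hypothetical stand-ins, not Lightning code.

```python
import inspect


def lookup_init_args(frame):
    """Hypothetical, simplified stand-in for the lookup in Lightning's _get_init_args."""
    local_vars = frame.f_locals
    if "__class__" not in local_vars:
        return {}
    cls = local_vars["__class__"]
    init_parameters = inspect.signature(cls.__init__).parameters
    # Same comprehension as line 96 in the traceback: it assumes every
    # parameter of cls.__init__ is a local variable of the inspected frame.
    return {k: local_vars[k] for k in init_parameters}


class Base:
    def fit(self):
        pass


class WithFit(Base):
    def fit(self):
        super().fit()  # zero-argument super() gives this frame a __class__ local
        return lookup_init_args(inspect.currentframe())


class WithInit(Base):
    def __init__(self):
        super().__init__()
        self.found = lookup_init_args(inspect.currentframe())


# In fit, __class__ is WithFit, whose __init__ is object.__init__ with the
# signature (self, /, *args, **kwargs), so 'args' is not a local of fit.
try:
    WithFit().fit()
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'args'

# In __init__, the signature is (self), so every parameter is a local.
print(list(WithInit().found))  # ['self']
```

In the fit case the lookup is applied to a frame that is not the constructor it reads the signature from, which is why the missing key is exactly 'args'.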

amitkparekh commented 11 months ago

I'm getting the same error, but in my case it happens because I use Typer/Click as the entry point to my application.

I don't see an easy fix other than not saving hyperparameters at all or not using Typer/Click.

I was able to patch it by changing line 96 of pytorch_lightning/utilities/parsing.py (in _get_init_args) to:

local_args = {k: local_vars[k] for k in init_parameters if k in local_vars}

I'm not sure whether this approach has wider ramifications, and I'm not entirely sure what the code is doing, but I'd be happy to submit a PR with some guidance.
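A minimal illustration of what the patched comprehension does, using a hand-written dict in place of real frame locals (patched_lookup and the dict contents are hypothetical). For the Model.fit frame from the original report, Model.__init__ resolves to object.__init__, whose signature is (self, /, *args, **kwargs); the patch simply skips args and kwargs because they are not local variables of fit.

```python
import inspect


def patched_lookup(local_vars, cls):
    """Hypothetical helper mirroring the patched line in _get_init_args."""
    init_parameters = inspect.signature(cls.__init__).parameters
    # The patch: skip signature names that are not locals of the frame
    return {k: local_vars[k] for k in init_parameters if k in local_vars}


class Model:  # stand-in: no __init__ of its own, so Model.__init__ is object.__init__
    pass


# Stand-in for the locals of Model.fit: only self and the __class__ cell
locals_of_fit = {"self": "<Model instance>", "__class__": Model}
print(patched_lookup(locals_of_fit, Model))  # {'self': '<Model instance>'}
```

One likely ramification, not verified against Lightning: a frame that is not an __init__ is still parsed as if it were one, so a local whose name happens to match a parameter of cls.__init__ could be picked up silently instead of raising.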

d-a-bunin commented 11 months ago

I had a similar issue and found a workaround. In your case, you could try creating a function like:

def create_model():
    return LightningModel(hidden_dim=2)

Then call it from Model.fit instead of instantiating LightningModel directly.
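A Lightning-free sketch of why this workaround likely helps. Judging by the collect_init_args excerpt in the traceback (lines 128 to 136 of parsing.py), once the walk has collected arguments from LightningModel.__init__ it only moves on to the caller's frame if that frame's locals contain __class__; a plain module-level function such as create_model never has that entry, so the walk would stop there before reaching Model.fit. This also assumes _get_init_args returns early for frames without __class__, which the traceback does not show. has_class_cell is a hypothetical helper.

```python
import inspect


def has_class_cell(frame):
    # The membership test the frame walk relies on (line 131 in the traceback)
    return "__class__" in frame.f_locals


def create_model():
    # Hypothetical stand-in for the helper above, which would return
    # LightningModel(hidden_dim=2); here it only reports on its own frame.
    return has_class_cell(inspect.currentframe())


class Base:
    def fit(self):
        pass


class Model(Base):
    def fit(self):
        super().fit()
        direct = has_class_cell(inspect.currentframe())  # frame of Model.fit itself
        via_helper = create_model()  # frame of the plain helper function
        return direct, via_helper


print(Model().fit())  # (True, False)
```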

jamesdeeel commented 5 months ago

I ran into this issue too. I think the crux of it is that if you call super() anywhere in a method that isn't __init__, then __class__ is added to that method's local variables. When this happens, the recursive argument parser https://github.com/Lightning-AI/pytorch-lightning/blob/dcb91d53d2133b4db1bf3201b4f965646dea76fd/src/lightning/pytorch/utilities/parsing.py#L136

assumes that Model is a subclass of LightningModel, so it tries to pull the initialisation arguments of Model.__init__(...) out of the frame's local variables and add them to the dict of values to be saved. The problem is that we aren't in Model.__init__(...), so those initialisation arguments are not present in the local variables.
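A small check of the crux described above, assuming CPython: zero-argument super() works because the compiler gives the method an implicit __class__ closure cell, and that cell is what appears among the frame's local variables. It is decided when the method is compiled, whatever the method is called. WithSuper and WithoutSuper are hypothetical stand-ins for the issue's Model.

```python
class Base:
    def fit(self):
        pass


class WithSuper(Base):
    def fit(self):
        super().fit()  # referencing super makes the compiler add a __class__ cell


class WithoutSuper(Base):
    def fit(self):
        Base.fit(self)  # explicit base-class call: no __class__ cell


print(WithSuper.fit.__code__.co_freevars)     # ('__class__',)
print(WithoutSuper.fit.__code__.co_freevars)  # ()
```

By the same reasoning, calling Base.fit(self) instead of super().fit() in Model.fit would presumably also avoid the error, since the fit frame would no longer carry __class__; this has not been tested against Lightning.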

I think amitkparekh's patch above,

local_args = {k: local_vars[k] for k in init_parameters if k in local_vars}

is the best approach, but alternatively there could be a check to make sure that we are in the __init__ method before collecting the local args:


if "__class__" in local_vars and (not classes or isinstance(local_self, classes)) and frame.f_back.f_code.co_name == "__init__":
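A hedged sketch of where such a guard might go. In the traceback, the KeyError is raised by the _get_init_args(frame) call on line 130 of parsing.py, before the condition on line 131 is evaluated, so adding the check to that condition alone would probably not prevent it; the proposed condition also tests frame.f_back, which appears to be the caller of the frame being inspected rather than the frame itself. The sketch below is a simplified, Lightning-free re-implementation modelled on the collect_init_args excerpt in the traceback; collect_init_args_guarded, _init_args_of and Mixin (a stand-in for HyperparametersMixin) are hypothetical names, not Lightning's API.

```python
import inspect
import types


class Mixin:
    """Hypothetical stand-in for HyperparametersMixin."""


def _init_args_of(frame):
    """Hypothetical, simplified stand-in for _get_init_args; only called on __init__ frames."""
    local_vars = frame.f_locals
    cls = local_vars["__class__"]
    params = list(inspect.signature(cls.__init__).parameters.values())[1:]  # drop self
    local_args = {
        p.name: local_vars[p.name]
        for p in params
        if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    }
    return local_vars["self"], local_args


def collect_init_args_guarded(frame, path_args, inside=False, classes=()):
    if not isinstance(frame.f_back, types.FrameType):
        return path_args
    local_vars = frame.f_locals
    # Guard: only parse a frame as a constructor frame if it really is an
    # __init__ frame, and check this before reading its signature's parameters.
    if "__class__" in local_vars and frame.f_code.co_name == "__init__":
        local_self, local_args = _init_args_of(frame)
        if not classes or isinstance(local_self, classes):
            path_args.append(local_args)
            return collect_init_args_guarded(frame.f_back, path_args, inside=True, classes=classes)
    if not inside:
        return collect_init_args_guarded(frame.f_back, path_args, inside=False, classes=classes)
    return path_args


class LightningModel(Mixin):  # stand-in for the issue's LightningModel
    def __init__(self, hidden_dim):
        super().__init__()
        self.collected = collect_init_args_guarded(inspect.currentframe(), [], classes=(Mixin,))


class Base:
    def fit(self):
        pass


class Model(Base):
    def fit(self):
        super().fit()
        self.model = LightningModel(hidden_dim=2)


model = Model()
model.fit()  # no KeyError
print(model.model.collected)  # [{'hidden_dim': 2}]
```

With the guard, the walk collects hidden_dim from LightningModel.__init__ and then stops at Model.fit instead of trying to read Model.__init__'s parameters from fit's locals.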