Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

LightningCLI: incorrect default value of kwarg used #18616

Open adamjstewart opened 1 year ago

adamjstewart commented 1 year ago

Bug description

I'm trying to port TorchGeo to LightningCLI: https://github.com/microsoft/torchgeo/pull/1559

99% of things are great! Unfortunately I'm still stuck on the last 1%. If you look at the failing tests, you can see two issues:

So2SatDataModule

When parsing So2SatDataModule, jsonargparse notices that kwargs may contain a version parameter. If it doesn't, the default is "2". However, jsonargparse detects the type as int and uses a default value of 2. Since "2" is in the dict but 2 isn't, the dataset fails to initialize.
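A minimal sketch of this failure mode, assuming a dataset that looks up its files in a dict keyed by version strings. `VERSIONS` and `load_so2sat` are hypothetical stand-ins, not TorchGeo's actual code.

```python
from typing import Any

# Hypothetical lookup table keyed by *string* versions, standing in for the
# dataset's real per-version metadata.
VERSIONS = {"2": "files for version 2", "3": "files for version 3"}


def load_so2sat(**kwargs: Any) -> str:
    # Same pattern as the datamodule: the default only applies if the key is absent.
    version = kwargs.get("version", "2")
    # Raises KeyError if version arrives as the int 2 instead of the str "2".
    return VERSIONS[version]


print(load_so2sat())             # key absent, so the default "2" is used
print(load_so2sat(version="2"))  # explicit string works
try:
    load_so2sat(version=2)       # an int injected by the parser
except KeyError as err:
    print("KeyError:", err)
```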

SeasonalContrastS2DataModule

When parsing SeasonalContrastS2DataModule, jsonargparse notices that kwargs may contain a bands parameter. If it doesn't, the default is ["B4", "B3", "B2"]. However, jsonargparse detects the type as Any and uses a default value of None. Obviously, None is invalid.
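Why an injected None is worse than no value at all: `kwargs.get(key, default)` only falls back to the default when the key is absent, so an explicit `bands=None` bypasses the code-level default entirely. A minimal sketch; `RGB_BANDS` and `num_channels` are hypothetical names used for illustration only.

```python
from typing import Any

# Hypothetical code-level default, mirroring the ["B4", "B3", "B2"] default above.
RGB_BANDS = ("B4", "B3", "B2")


def num_channels(**kwargs: Any) -> int:
    # Falls back to RGB_BANDS only when "bands" is *absent* from kwargs.
    bands = kwargs.get("bands", RGB_BANDS)
    # If a parser injects bands=None, len(None) raises TypeError.
    return len(bands)


print(num_channels())              # key absent: code default used
print(num_channels(bands=["B8"]))  # explicit value
try:
    num_channels(bands=None)       # key present, value None
except TypeError as err:
    print("TypeError:", err)
```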

It's still not clear to me if this is a lightning issue or a jsonargparse issue, but I'm surprised that it's trying to poke into my kwargs at all. Isn't that the point of dict_kwargs, for things that can't be type checked? Even more surprising is that even though my seasonal contrast and so2sat config.yaml files don't use these parameters, lightning is adding a default value anyway.

Is this a bug in jsonargparse or lightning? Is there any way to disable this aggressive kwarg parsing or the use of default values?

What version are you seeing the problem on?

v2.0

How to reproduce the bug

No response

Error messages and logs

No response

Environment

Current environment
* CUDA:
    - GPU: None
    - available: False
    - version: None
* Lightning:
    - efficientnet-pytorch: 0.7.1
    - lightning: 2.0.9
    - lightning-cloud: 0.5.38
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.0
    - pytorch-sphinx-theme: 0.0.24
    - segmentation-models-pytorch: 0.3.3
    - torch: 2.0.1
    - torchmetrics: 1.1.1
    - torchvision: 0.15.2
* Packages:
    - absl-py: 1.4.0 - aenum: 3.1.12 - affine: 2.1.0 - aiohttp: 3.8.4 - aiosignal: 1.2.0 - alabaster: 0.7.13 - antlr4-python3-runtime: 4.9.3 - anyio: 3.6.2 - appdirs: 1.4.4 - appnope: 0.1.3 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - asttokens: 2.2.1 - astunparse: 1.6.3 - async-lru: 1.0.3 - async-timeout: 4.0.2 - attrs: 23.1.0 - babel: 2.12.1 - backcall: 0.2.0 - backoff: 2.2.1 - beautifulsoup4: 4.12.2 - black: 23.9.1 - bleach: 6.0.0 - blessed: 1.19.0 - bottleneck: 1.3.7 - build: 1.0.3 - cachetools: 5.2.0 - cartopy: 0.22.0 - certifi: 2023.5.7 - cffi: 1.15.1 - cftime: 1.0.3.4 - charset-normalizer: 3.1.0 - click: 8.1.3 - click-plugins: 1.1.1 - cligj: 0.7.2 - cmocean: 2.0 - colorama: 0.4.6 - comm: 0.1.3 - contourpy: 1.0.7 - coverage: 7.2.6 - croniter: 1.3.8 - cycler: 0.11.0 - cython: 0.29.36 - dateutils: 0.6.12 - debugpy: 1.6.7 - decorator: 5.1.1 - deepdiff: 6.3.0 - defusedxml: 0.7.1 - docstring-parser: 0.15 - docutils: 0.18.1 - editables: 0.3 - efficientnet-pytorch: 0.7.1 - einops: 0.6.1 - et-xmlfile: 1.0.1 - executing: 1.2.0 - fastapi: 0.98.0 - fastjsonschema: 2.16.3 - filelock: 3.12.0 - fiona: 1.9.4 - flake8: 6.1.0 - flit-core: 3.9.0 - fonttools: 4.39.4 - fqdn: 1.5.1 - frozenlist: 1.3.1 - fsspec: 2023.1.0 - gdal: 3.7.2 - geocube: 0.3.2 - geopandas: 0.11.1 - gevent: 23.7.0 - google-auth: 2.20.0 - google-auth-oauthlib: 0.5.2 - greenlet: 2.0.2 - grpcio: 1.52.0 - h11: 0.13.0 - h5py: 3.8.0 - hatch-jupyter-builder: 0.8.3 - hatchling: 1.17.0 - huggingface-hub: 0.14.1 - hydra-core: 1.3.1 - idna: 3.4 - imagesize: 1.4.1 - importlib-metadata: 6.6.0 - importlib-resources: 5.12.0 - iniconfig: 2.0.0
    - inquirer: 3.1.3 - ipykernel: 6.23.1 - ipython: 8.14.0 - ipywidgets: 8.0.2 - isoduration: 20.11.0 - isort: 5.12.0 - itsdangerous: 2.1.2 - jaraco.classes: 3.2.3 - jedi: 0.18.2 - jinja2: 3.0.3 - joblib: 1.2.0 - json5: 0.9.14 - jsonargparse: 4.19.0 - jsonpointer: 2.0 - jsonschema: 4.17.3 - jupyter-client: 8.2.0 - jupyter-core: 5.3.0 - jupyter-events: 0.6.3 - jupyter-lsp: 2.2.0 - jupyter-server: 2.6.0 - jupyter-server-terminals: 0.4.4 - jupyterlab: 4.0.1 - jupyterlab-pygments: 0.2.2 - jupyterlab-server: 2.22.1 - jupyterlab-widgets: 3.0.3 - keyring: 23.13.1 - kiwisolver: 1.4.4 - kornia: 0.7.0 - laspy: 2.2.0 - lightly: 1.4.18 - lightly-utils: 0.0.2 - lightning: 2.0.9 - lightning-cloud: 0.5.38 - lightning-utilities: 0.8.0 - markdown: 3.4.1 - markdown-it-py: 3.0.0 - markupsafe: 2.1.3 - matplotlib: 3.8.0 - matplotlib-inline: 0.1.6 - mccabe: 0.7.0 - mdurl: 0.1.2 - mistune: 2.0.5 - more-itertools: 9.1.0 - mpmath: 1.2.1 - multidict: 6.0.4 - munch: 2.5.0 - mypy: 1.3.0 - mypy-extensions: 1.0.0 - nbclient: 0.6.7 - nbconvert: 7.4.0 - nbformat: 5.8.0 - nbmake: 1.4.3 - nbsphinx: 0.8.8 - nest-asyncio: 1.5.6 - netcdf4: 1.6.2 - networkx: 3.1 - notebook-shim: 0.2.3 - numexpr: 2.8.4 - numpy: 1.25.2 - oauthlib: 3.2.1 - odc-geo: 0.1.2 - omegaconf: 2.3.0 - openpyxl: 3.1.2 - ordered-set: 4.0.2 - overrides: 7.3.1 - packaging: 23.1 - pandas: 2.0.2 - pandocfilters: 1.5.0 - parso: 0.8.3 - pathspec: 0.11.1 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.5.0 - pip: 23.0 - pkginfo: 1.9.6 - planetary-computer: 0.4.9 - platformdirs: 3.5.3 - pluggy: 1.0.0 - poetry-core: 1.6.1 - pretrainedmodels: 0.7.4 - prometheus-client: 0.17.0 - prompt-toolkit: 3.0.38 - protobuf: 3.20.3 - psutil: 5.9.5 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pybind11: 2.10.1 - pycocotools: 2.0.6 - pycodestyle: 2.11.0 - pycparser: 2.21 - pydantic: 1.10.9 - pydocstyle: 6.2.1 - pyflakes: 3.1.0 - pygeos: 0.10 - pygments: 2.15.1 - pyjwt: 2.4.0 - pyparsing: 3.0.9 - pyproj: 3.6.0 - pyproject-hooks: 1.0.0
    - pyrsistent: 0.19.3 - pyshp: 2.1.0 - pystac: 1.4.0 - pystac-client: 0.5.1 - pytest: 7.3.2 - pytest-cov: 4.0.0 - python-dateutil: 2.8.2 - python-dotenv: 0.19.2 - python-editor: 1.0.4 - python-json-logger: 2.0.7 - python-multipart: 0.0.5 - pytorch-lightning: 2.0.0 - pytorch-sphinx-theme: 0.0.24 - pytz: 2023.3 - pyupgrade: 3.3.1 - pyyaml: 6.0 - pyzmq: 25.0.2 - radiant-mlhub: 0.5.1 - rarfile: 4.1 - rasterio: 1.3.8 - readchar: 4.0.5 - readme-renderer: 37.3 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - requests-toolbelt: 1.0.0 - rfc3339-validator: 0.1.4 - rfc3986: 2.0.0 - rfc3986-validator: 0.1.1 - rich: 13.4.2 - rioxarray: 0.4.1.post0 - rsa: 4.9 - rtree: 1.0.1 - safetensors: 0.3.1 - scikit-learn: 1.3.1 - scipy: 1.10.1 - segmentation-models-pytorch: 0.3.3 - send2trash: 1.8.0 - setuptools: 63.4.3 - setuptools-scm: 7.1.0 - shapely: 1.8.4 - six: 1.16.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - snuggs: 1.4.1 - soupsieve: 2.4.1 - sphinx: 5.3.0 - sphinx-design: 0.4.1 - sphinx-rtd-theme: 1.2.2 - sphinxcontrib-applehelp: 1.0.2 - sphinxcontrib-devhelp: 1.0.2 - sphinxcontrib-htmlhelp: 2.0.0 - sphinxcontrib-jquery: 4.1 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-programoutput: 0.15 - sphinxcontrib-qthelp: 1.0.3 - sphinxcontrib-serializinghtml: 1.1.9 - stack-data: 0.6.2 - starlette: 0.27.0 - starsessions: 1.3.0 - sympy: 1.11.1 - tensorboard: 2.13.0 - tensorboard-data-server: 0.7.0 - tensorboard-plugin-wit: 1.8.1 - terminado: 0.17.1 - threadpoolctl: 3.1.0 - timm: 0.9.2 - tinycss2: 1.1.1 - tokenize-rt: 4.2.1 - tomli: 2.0.1 - torch: 2.0.1 - torchmetrics: 1.1.1 - torchvision: 0.15.2 - tornado: 6.2 - tqdm: 4.65.0 - traitlets: 5.9.0 - trove-classifiers: 2023.3.9 - twine: 4.0.2 - typeshed-client: 2.1.0 - typing-extensions: 4.6.3 - tzdata: 2023.3 - uri-template: 1.2.0 - urllib3: 1.26.12 - uvicorn: 0.20.0 - vermin: 1.5.2 - wcwidth: 0.2.5 - webcolors: 1.11.1 - webencodings: 0.5.1 - websocket-client: 1.5.1 - websockets: 10.4 - werkzeug: 2.3.4 - wheel: 0.37.1 - widgetsnbextension: 4.0.3
    - xarray: 2023.7.0 - yarl: 1.8.1 - zipfile-deflate64: 0.2.0 - zipp: 3.8.1 - zope.event: 4.6 - zope.interface: 5.4.0
* System:
    - OS: Darwin
    - architecture: 64bit
    - processor: arm
    - python: 3.11.4
    - release: 22.6.0
    - version: Darwin Kernel Version 22.6.0: Wed Jul 5 22:21:53 PDT 2023; root:xnu-8796.141.3~6/RELEASE_ARM64_T6020

More info

No response

cc @carmocca @mauvilsa

adamjstewart commented 1 year ago

To try to debug if this was a jsonargparse bug or a lightning bug, I tried using:

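from jsonargparse import CLI
from torchgeo.datamodules import SeasonalContrastS2DataModule, So2SatDataModule

if __name__ == "__main__":
    # Swap which line is commented out to test the other data module.
    # CLI(SeasonalContrastS2DataModule)
    CLI(So2SatDataModule)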

For both data modules, jsonargparse detects the correct types, or ignores the kwargs altogether if it can't. It also doesn't seem to inject fake default values when loading a config.yaml; it leaves them blank. I'm guessing both of these are introduced by lightning, but I'm not sure where.

adamjstewart commented 1 year ago

Also, the problem still exists when I replace jsonargparse with omegaconf a la this tutorial.

mauvilsa commented 1 year ago

When parsing So2SatDataModule, jsonargparse notices that kwargs may contain a version parameter. If it doesn't, the default is "2". However, jsonargparse detects the type as int and uses a default value of 2. Since "2" is in the dict but 2 isn't, the dataset fails to initialize.

I was not able to reproduce this. The inferred type for version is str. A better way to observe this is:

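from jsonargparse import ArgumentParser
from torchgeo.datamodules import So2SatDataModule

# Add the datamodule's __init__ parameters under the "data" key, then print
# the help, which shows the type and default inferred for each parameter.
parser = ArgumentParser()
parser.add_class_arguments(So2SatDataModule, 'data')
parser.print_help()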

Even when running the unit test that failed in torchgeo, the type I got was str and the tests did not fail (run with commit f4d1436).

When parsing SeasonalContrastS2DataModule, jsonargparse notices that kwargs may contain a bands parameter. If it doesn't, the default is ["B4", "B3", "B2"]. However, jsonargparse detects the type as Any and uses a default value of None. Obviously, None is invalid.

Resolving a default from code like that in seco.py#L36 is currently not supported, though getting the default is not strictly necessary. The actual problem is a bug in jsonargparse, specifically in line _parameter_resolvers.py#L653. I will fix this.
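One way to see why a literal default such as "2" can be recovered statically while a default computed from code cannot: a resolver that does not run the code can only evaluate literal expressions. A minimal sketch using the standard library's `ast.literal_eval`; it illustrates the idea and is not jsonargparse's implementation. `try_static_default` and `Dataset.rgb_bands` are hypothetical names.

```python
import ast
from typing import Any, Optional, Tuple


def try_static_default(source: str) -> Tuple[bool, Optional[Any]]:
    """Return (resolved, value) for a default value written as Python source."""
    try:
        # literal_eval accepts only literals: strings, numbers, lists, tuples, dicts, ...
        return True, ast.literal_eval(source)
    except (ValueError, SyntaxError):
        # Names, attribute access and calls cannot be evaluated without running code.
        return False, None


print(try_static_default('"2"'))                 # plain string literal
print(try_static_default('["B4", "B3", "B2"]'))  # list literal
print(try_static_default("Dataset.rgb_bands"))   # attribute access: unresolved
```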

I'm surprised that it's trying to poke into my kwargs at all. Isn't that the point of dict_kwargs, for things that can't be type checked?

Skipping the type check applies only to unresolved parameters. version is resolved, since it follows the kwargs.get pattern, which is supported as explained in ast-resolver.
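To illustrate what "resolving" a kwargs.get parameter means, here is a minimal sketch that statically scans a function's source for `kwargs.get("name", default)` calls and reports each key together with the source text of its default. It is a simplified illustration of the idea, not jsonargparse's ast-resolver; `SOURCE`, `find_kwargs_get` and `RGB_BANDS` are hypothetical.

```python
import ast
from typing import Dict

# Hypothetical __init__ source mirroring the kwargs.get pattern discussed here;
# RGB_BANDS is an illustrative name and is never evaluated.
SOURCE = '''
def __init__(self, batch_size=64, **kwargs):
    self.version = kwargs.get("version", "2")
    self.bands = kwargs.get("bands", RGB_BANDS)
'''


def find_kwargs_get(source: str, kwargs_name: str = "kwargs") -> Dict[str, str]:
    """Map each key read via <kwargs_name>.get("key", default) to its default's source."""
    found: Dict[str, str] = {}
    for node in ast.walk(ast.parse(source)):
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == "get"
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == kwargs_name
            and node.args
            and isinstance(node.args[0], ast.Constant)
            and isinstance(node.args[0].value, str)
        ):
            # dict.get returns None when no default is given.
            default = ast.unparse(node.args[1]) if len(node.args) > 1 else "None"
            found[node.args[0].value] = default
    return found


print(find_kwargs_get(SOURCE))
```

Both keys are discovered, but only the literal default ('2') could be turned into a concrete value and type; RGB_BANDS remains an unevaluated expression.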

the problem still exists when I replace jsonargparse with omegaconf a la this tutorial.

That does not replace jsonargparse with omegaconf. It replaces pyyaml with omegaconf. In both cases jsonargparse is used.

adamjstewart commented 1 year ago

I was not able to reproduce this.

To clarify, I think the issue is that it's reading "2" from my config.yaml and deciding to cast it to an integer for some reason.

Even when running the unit test that failed in torchgeo, the type I got was str and the tests did not fail (run with commit f4d1436).

This part surprises me. It fails in CI and it fails locally for me. I'm not sure what would be different in your environment that would cause the test to pass. Maybe a development version of something?

I will fix this.

Thanks!

mauvilsa commented 1 year ago

The actual problem is a bug in jsonargparse, specifically in line _parameter_resolvers.py#L653. I will fix this.

This is now fixed and it is part of jsonargparse 4.25.0 which has just been released.

To clarify I think the issue is that it's reading "2" from my config.yaml and deciding to cast it to an integer for some reason.

What is the motivation for doing version = kwargs.get("version", "2") instead of making version an actual named parameter of __init__? If there is no real motivation, I would suggest changing it. Note that as it is now, it is not possible to know the expected type for version. With jsonargparse v4.25.0, the type will be Union[str, Any]. Since it accepts any type, if someone writes version: 2 in a YAML config, pyyaml parses it as an int and that is how the class will receive it. If it were version: "2", it would be passed as a str. The code must handle both int and str, because without a proper type hint jsonargparse will not warn about invalid values.
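A minimal sketch of the point about type hints, using the standard library's `inspect` module: a parameter read from `**kwargs` does not appear in the signature at all, whereas a named, annotated parameter exposes both its type and its default to any signature-based tool. Both classes are hypothetical.

```python
import inspect
from typing import Any


class KwargsOnly:
    def __init__(self, **kwargs: Any) -> None:
        self.version = kwargs.get("version", "2")


class NamedParam:
    def __init__(self, version: str = "2", **kwargs: Any) -> None:
        self.version = version


for cls in (KwargsOnly, NamedParam):
    params = inspect.signature(cls.__init__).parameters
    # Only named parameters are visible; values read from **kwargs are not.
    print(cls.__name__, list(params))

version_param = inspect.signature(NamedParam.__init__).parameters["version"]
print(version_param.annotation, repr(version_param.default))
```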

adamjstewart commented 1 year ago

This is now fixed and it is part of jsonargparse 4.25.0 which has just been released.

I can confirm that this fixes the issue, thanks for the quick fix and release!

What is the motivation for doing ...

All of our data modules are wrappers around our datasets. To avoid documenting parameters twice, we keep all dataset parameters as kwargs. It just so happens that this datamodule needs access to one of the dataset parameters. If we make this a formal parameter of the datamodule, we can avoid this issue, but then we need to manually add it back to kwargs. I'm not opposed to this workaround though, and might use it in the meantime if we can't fix this bug.
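A sketch of the workaround described above, assuming a datamodule that forwards all extra keyword arguments to its dataset. `So2SatLikeDataModule`, `FakeDataset` and `make_dataset` are hypothetical stand-ins, not TorchGeo or Lightning classes.

```python
from typing import Any, Dict


class FakeDataset:
    """Hypothetical dataset that just records the keyword arguments it receives."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs: Dict[str, Any] = kwargs


class So2SatLikeDataModule:
    def __init__(self, batch_size: int = 64, version: str = "2", **kwargs: Any) -> None:
        # version is now a formal, type-hinted parameter a CLI parser can inspect...
        self.batch_size = batch_size
        self.version = version
        # ...but the dataset still expects it, so it is added back to kwargs.
        self.kwargs = {**kwargs, "version": version}

    def make_dataset(self) -> FakeDataset:
        return FakeDataset(**self.kwargs)


dm = So2SatLikeDataModule(root="data")
print(dm.make_dataset().kwargs)
```

The cost is that version is documented on the datamodule as well as on the dataset, which is the duplication the current design avoids.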

If it were version: "2", then it is passed as str.

That's the weird part, it is version: "2": https://github.com/microsoft/torchgeo/pull/1559/files#diff-0d7ed937c136111e81d0b80d253d1e11f38cbd4857dcdedd4dd9c392e80c366dR15

For some reason jsonargparse/lightning is casting it to an int, which causes the aforementioned bug.

adamjstewart commented 1 year ago

This is now fixed and it is part of jsonargparse 4.25.0 which has just been released.

One last thing that needs to change. Can this constraint be relaxed?

Also, this line might need to be updated if the min is now 4.18.

carmocca commented 1 year ago

I agree with both suggestions @adamjstewart. Would you like to open a PR applying them?

mauvilsa commented 1 year ago

@adamjstewart a small comment about your torchgeo pull request. Using lightning[pytorch-extra]==2.0.9 as a requirement forces the installation of all optional requirements. If that is the intention, then all good. But the pull request is only about LightningCLI, and not all of those requirements are necessary for it. To install only what is needed for the CLI, remove [pytorch-extra] and add jsonargparse[signatures] as a separate requirement.

adamjstewart commented 1 year ago

Yeah, I'm torn about which to use. We also use tensorboard, and may also use omegaconf/hydra in the future. I don't love adding a dep on jsonargparse because we don't actually use it in torchgeo itself, but that would help minimize deps.

mauvilsa commented 1 year ago

For some reason jsonargparse/lightning is casting it to an int, which causes the aforementioned bug.

I have now managed to reproduce the issue, and it would be considered a bug in jsonargparse, though I am not sure how to fix it.

Independently of this, note that people will be writing the configs, and it is quite likely that someone will write version: 2. That would cause a failure unrelated to any bug. It is better to cast the value to a string when an int is given.
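A minimal sketch of that suggestion, assuming the dataset looks up versions in a dict keyed by strings. `VERSIONS` and `resolve_version` are hypothetical names used for illustration.

```python
from typing import Any

# Hypothetical lookup table keyed by string versions.
VERSIONS = {"2": "files for version 2", "3": "files for version 3"}


def resolve_version(**kwargs: Any) -> str:
    """Return the requested version as a str, whether the config gave 2 or "2"."""
    # Without a type hint, a YAML config can deliver `version: 2` (int) or
    # `version: "2"` (str); str() makes both usable as a dict key.
    return str(kwargs.get("version", "2"))


for raw in ({}, {"version": 2}, {"version": "2"}):
    print(raw, "->", VERSIONS[resolve_version(**raw)])
```

This only normalizes ints; a value such as 2.0 would still become "2.0" and fail the lookup, so a type-hinted parameter remains the more robust option.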

adamjstewart commented 2 months ago

Pretty sure this was fixed a long time ago, closing.

adamjstewart commented 2 months ago

I spoke too soon, the So2SatDataModule bug is still occurring for me: