Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

WandBLogger: Can't set log_model from LightningCLI due to problem in type hint #18370

Closed. SebastianGer closed this issue 1 year ago.

SebastianGer commented 1 year ago

Bug description

The WandB logger has an init argument `log_model` of type `Union[str, bool]` (see the class definition).

I can set this value to `True` or `False` in a config file passed to the LightningCLI. However, this does not work from the command line: setting `--trainer.logger.init_args.log_model=True` results in the value being represented internally as `'True'`, a string.

From what I can tell, this can be fixed by switching the order of the types in the `Union` type hint, so that the boolean interpretation is preferred when both are possible. Even better would be to use the `Literal` type, since according to the docs only one string value should actually be accepted: `Union[Literal["all"], bool]`. Both fixes work in my testing.
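To make the `Literal` proposal concrete, here is a small stdlib-only sketch comparing which values each hint admits. The `admits` helper is hypothetical and is not how jsonargparse validates types; it handles only `Union`, `Literal`, and plain classes, and it only illustrates that the proposed hint narrows the accepted strings to `"all"`.

```python
from typing import Literal, Union, get_args, get_origin

# The hint currently declared for log_model: any string is accepted.
CurrentLogModel = Union[str, bool]
# The hint proposed in this issue: "all" is the only accepted string.
ProposedLogModel = Union[Literal["all"], bool]


def admits(hint, value) -> bool:
    """Simplified check of whether `value` is allowed by `hint` (hypothetical helper)."""
    origin = get_origin(hint)
    if origin is Union:
        # A Union admits a value if any of its members does.
        return any(admits(member, value) for member in get_args(hint))
    if origin is Literal:
        # A Literal admits only the listed values (compared with ==).
        return value in get_args(hint)
    return isinstance(value, hint)


for value in ("all", "True", True, False):
    print(repr(value), admits(CurrentLogModel, value), admits(ProposedLogModel, value))
```

Under the current hint the string `'True'` is admitted, which is how it ends up stored as a string; the proposed hint rejects it, leaving `bool` as the only member that can take that value.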

What version are you seeing the problem on?

v2.0

How to reproduce the bug

The underlying issue is in jsonargparse, which LightningCLI is built on. The following script reproduces it:

```python
from typing import Union

import jsonargparse

parser = jsonargparse.ArgumentParser()
parser.add_argument("--bool_arg", type=bool)
# The same two member types, declared in opposite order
parser.add_argument("--union_arg", type=Union[str, bool])
parser.add_argument("--union_arg2", type=Union[bool, str])

args = parser.parse_args()
print(type(args.bool_arg), type(args.union_arg), type(args.union_arg2))
```

Running this with `python src/test_bool_parsing.py --bool_arg=True --union_arg=True --union_arg2=True` prints `<class 'bool'> <class 'str'> <class 'bool'>`. The order of the types in the `Union` hint is clearly what matters.

We can also pass `'True'` or `'yes'`, i.e. something that is clearly a string (based on the quotation marks), to `union_arg2`, and it is still interpreted as a bool, apparently because `bool` is the first type in the `Union`. Passing the same values to `union_arg` yields only strings, with or without quotation marks.
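One detail that may explain part of this: when the command is typed in a POSIX shell such as bash, the shell removes the quotation marks before Python starts, so the parser receives `--union_arg2=True` either way. The sketch below uses the standard library's `shlex.split`, which follows POSIX shell quoting rules, to show what would end up in `sys.argv`. It does not exercise jsonargparse, and it makes no claim about how jsonargparse treats a value that still contains quote characters.

```python
import shlex

# shlex.split (POSIX mode by default) applies shell-style quote removal.
command = "python src/test_bool_parsing.py --union_arg2='True' --union_arg='yes'"
argv = shlex.split(command)
print(argv[2:])  # the two option tokens, with the quotes already gone

# Keeping literal quote characters requires a second layer of quoting.
quoted = shlex.split("--union_arg2=\"'True'\"")
print(quoted)
```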

Error messages and logs

No response

Environment

Current environment

* CUDA:
  - GPU: 4x Tesla T4
  - available: True
  - version: 11.7
* Lightning:
  - efficientnet-pytorch: 0.7.1
  - lightning: 2.0.1
  - lightning-cloud: 0.5.32
  - lightning-utilities: 0.8.0
  - pytorch-lightning: 2.0.1
  - segmentation-models-pytorch: 0.3.2
  - torch: 2.0.0
  - torchaudio: 2.0.1
  - torchmetrics: 0.11.4
  - torchvision: 0.15.1
* Packages: absl-py: 1.4.0, affine: 2.4.0, aiofiles: 22.1.0, aiohttp: 3.8.4, aiosignal: 1.3.1, aiosqlite: 0.19.0, alabaster: 0.7.12, antlr4-python3-runtime: 4.9.3, anyio: 3.6.2, appdirs: 1.4.4, argon2-cffi: 21.3.0, argon2-cffi-bindings: 21.2.0, arrow: 1.2.3, asciitree: 0.3.3, asn1crypto: 1.5.1, asttokens: 2.2.1, astunparse: 1.6.3, async-timeout: 4.0.2, atomicwrites: 1.4.0, attrs: 21.4.0, autopep8: 2.0.2, babel: 2.10.1, backcall: 0.2.0, backports.entry-points-selectable: 1.1.1, backports.functools-lru-cache: 1.6.4, bcrypt: 3.2.2, beautifulsoup4: 4.12.1, beniget: 0.4.1, bitstring: 3.1.9, bleach: 6.0.0, blessed: 1.20.0, blist: 1.3.6, bokeh: 2.4.3, bottleneck: 1.3.4, cachecontrol: 0.12.11, cachetools: 5.3.0, cachy: 0.3.0, certifi: 2021.10.8, cffi: 1.15.0, cfgrib: 0.9.10.3, cftime: 1.6.2, chardet: 4.0.0, charset-normalizer: 2.0.12, cleo: 0.8.1, click: 8.1.3, click-plugins: 1.1.1, cligj: 0.7.2, clikit: 0.6.2, cloudpickle: 2.2.1, cmake: 3.26.1, colorama: 0.4.4, comm: 0.1.3, contourpy: 1.0.7, crashtest: 0.3.1, crc32c: 2.3.post0, croniter: 1.3.8, cryptography: 37.0.1, cycler: 0.11.0, cython: 0.29.28, dask: 2023.3.2, dateutils: 0.6.12, deap: 1.3.1, debugpy: 1.6.7, decorator: 5.1.1, deepdiff: 6.3.0, defusedxml: 0.7.1, distlib: 0.3.4, distributed: 2023.3.2, dnspython: 2.3.0, docker-pycreds: 0.4.0, docopt: 0.6.2, docstring-parser: 0.15, docutils: 0.17.1, eccodes: 1.5.2, ecdsa: 0.17.0, editables: 0.3, efficientnet-pytorch: 0.7.1, einops: 0.6.1, email-validator: 1.3.1, entrypoints: 0.4, executing: 1.2.0, fastapi: 0.88.0, fasteners: 0.18, fastjsonschema: 2.16.3, filelock: 3.6.0, findlibs: 0.0.2, fiona: 1.9.2, flatbuffers: 23.3.3, flit: 3.7.1, flit-core: 3.7.1, flox: 0.6.10, fonttools: 4.39.3, fqdn: 1.5.1, frozenlist: 1.3.3, fsspec: 2023.3.0, future: 0.18.2, gast: 0.5.3, geopandas: 0.12.2, gitdb: 4.0.10, gitpython: 3.1.31, glob2: 0.7, google-auth: 2.17.3, google-auth-oauthlib: 1.0.0, google-pasta: 0.2.0, googleapis-common-protos: 1.59.0, grpcio: 1.54.0, h11: 0.14.0, h5netcdf: 1.1.0, h5py: 3.7.0, heapdict: 1.0.1, html5lib: 1.1, httpcore: 0.16.3, httplib2: 0.22.0, httptools: 0.5.0, httpx: 0.23.3, huggingface-hub: 0.13.4, idna: 3.3, imageio: 2.27.0, imagesize: 1.3.0, importlib-metadata: 6.1.0, importlib-resources: 5.7.1, iniconfig: 1.1.1, inquirer: 3.1.3, intervaltree: 3.1.0, intreehooks: 1.0, ipaddress: 1.0.23, ipykernel: 6.22.0, ipython: 8.12.0, ipython-genutils: 0.2.0, ipywidgets: 8.0.6, isoduration: 20.11.0, itsdangerous: 2.1.2, jax: 0.4.8, jedi: 0.18.2, jeepney: 0.8.0, jinja2: 3.1.2, joblib: 1.2.0, json5: 0.9.11, jsonargparse: 4.20.1, jsonpointer: 2.3, jsonschema: 4.17.3, jupyter-client: 8.2.0, jupyter-core: 5.3.0, jupyter-events: 0.6.3, jupyter-server: 2.5.0, jupyter-server-fileid: 0.9.0, jupyter-server-terminals: 0.4.4, jupyter-server-ydoc: 0.8.0, jupyter-ydoc: 0.2.4, jupyterlab: 3.6.3, jupyterlab-pygments: 0.2.2, jupyterlab-server: 2.22.1, jupyterlab-widgets: 3.0.7, keras: 2.12.0, keyring: 23.5.0, keyrings.alt: 4.1.0, kiwisolver: 1.4.4, liac-arff: 2.5.0, libclang: 16.0.0, lightning: 2.0.1, lightning-cloud: 0.5.32, lightning-utilities: 0.8.0, lit: 16.0.0, llvmlite: 0.39.1, locket: 1.0.0, lockfile: 0.12.2, lz4: 4.3.2, markdown: 3.4.3, markdown-it-py: 2.2.0, markupsafe: 2.1.1, matplotlib: 3.7.1, matplotlib-inline: 0.1.6, mdurl: 0.1.2, mistune: 2.0.5, ml-dtypes: 0.1.0, mock: 4.0.3, more-itertools: 8.12.0, mpi4py: 3.1.3, mpmath: 1.2.1, msgpack: 1.0.3, multidict: 6.0.4, munch: 2.5.0, nbclassic: 0.5.5, nbclient: 0.7.3, nbconvert: 7.3.1, nbformat: 5.8.0, nc-time-axis: 1.4.1, nest-asyncio: 1.5.6, netaddr: 0.8.0, netcdf4: 1.6.3, netifaces: 0.11.0, networkx: 3.1, notebook: 6.5.4, notebook-shim: 0.2.2, numba: 0.56.4, numbagg: 0.2.2, numcodecs: 0.11.0, numexpr: 2.8.1, numpy: 1.22.3, numpy-groupies: 0.9.20, nvidia-cublas-cu11: 11.10.3.66, nvidia-cuda-cupti-cu11: 11.7.101, nvidia-cuda-nvrtc-cu11: 11.7.99, nvidia-cuda-runtime-cu11: 11.7.99, nvidia-cudnn-cu11: 8.5.0.96, nvidia-cufft-cu11: 10.9.0.58, nvidia-curand-cu11: 10.2.10.91, nvidia-cusolver-cu11: 11.4.0.1, nvidia-cusparse-cu11: 11.7.4.91, nvidia-nccl-cu11: 2.14.3, nvidia-nvtx-cu11: 11.7.91, oauth2client: 4.1.3, oauthlib: 3.2.2, omegaconf: 2.3.0, opencv-python: 4.7.0.72, opt-einsum: 3.3.0, ordered-set: 4.1.0, orjson: 3.8.9, packaging: 23.0, pandas: 1.4.2, pandocfilters: 1.5.0, paramiko: 2.10.4, parso: 0.8.3, partd: 1.3.0, pastel: 0.2.1, pathlib2: 2.3.7.post1, pathspec: 0.9.0, pathtools: 0.1.2, pbr: 5.8.1, pexpect: 4.8.0, pickleshare: 0.7.5, pillow: 9.5.0, pip: 23.2.1, pkginfo: 1.8.2, platformdirs: 3.2.0, plotly: 5.15.0, pluggy: 1.0.0, ply: 3.11, poetry: 1.1.13, poetry-core: 1.0.8, pooch: 1.7.0, pretrainedmodels: 0.7.4, prometheus-client: 0.16.0, prompt-toolkit: 3.0.38, protobuf: 3.20.3, psutil: 5.9.0, ptyprocess: 0.7.0, pure-eval: 0.2.2, py: 1.11.0, py-expression-eval: 0.3.14, pyarrow: 11.0.0, pyasn1: 0.4.8, pyasn1-modules: 0.3.0, pybind11: 2.9.2, pycodestyle: 2.10.0, pycparser: 2.21, pycrypto: 2.6.1, pydantic: 1.10.7, pygments: 2.14.0, pyjwt: 2.6.0, pylev: 1.4.0, pynacl: 1.5.0, pyparsing: 3.0.8, pyproj: 3.5.0, pyrsistent: 0.18.1, pytest: 7.1.2, python-dateutil: 2.8.2, python-dotenv: 1.0.0, python-editor: 1.0.4, python-json-logger: 2.0.7, python-multipart: 0.0.6, pythran: 0.11.0, pytoml: 0.1.21, pytorch-lightning: 2.0.1, pytz: 2022.1, pyyaml: 6.0, pyzmq: 25.0.2, rasterio: 1.3.6, readchar: 4.0.5, regex: 2022.4.24, requests: 2.28.2, requests-oauthlib: 1.3.1, requests-toolbelt: 0.9.1, rfc3339-validator: 0.1.4, rfc3986: 1.5.0, rfc3986-validator: 0.1.1, rich: 13.3.3, rsa: 4.9, scandir: 1.10.0, scikit-learn: 1.2.2, scipy: 1.8.1, seaborn: 0.12.2, secretstorage: 3.3.2, segmentation-models-pytorch: 0.3.2, semantic-version: 2.9.0, send2trash: 1.8.0, sentry-sdk: 1.19.1, setproctitle: 1.3.2, setuptools: 62.1.0, setuptools-rust: 1.3.0, setuptools-scm: 6.4.2, shapely: 2.0.1, shellingham: 1.4.0, simplegeneric: 0.8.1, simplejson: 3.17.6, six: 1.16.0, smmap: 5.0.0, sniffio: 1.3.0, snowballstemmer: 2.2.0, snuggs: 1.4.7, sortedcontainers: 2.4.0, soupsieve: 2.4, sphinx: 4.5.0, sphinx-bootstrap-theme: 0.8.1, sphinxcontrib-applehelp: 1.0.2, sphinxcontrib-devhelp: 1.0.2, sphinxcontrib-htmlhelp: 2.0.0, sphinxcontrib-jsmath: 1.0.1, sphinxcontrib-qthelp: 1.0.3, sphinxcontrib-serializinghtml: 1.1.5, sphinxcontrib-websupport: 1.2.4, stack-data: 0.6.2, starlette: 0.22.0, starsessions: 1.3.0, sympy: 1.11.1, tabulate: 0.8.9, tblib: 1.7.0, tenacity: 8.2.2, tensorboard: 2.12.3, tensorboard-data-server: 0.7.0, tensorflow: 2.12.0, tensorflow-estimator: 2.12.0, tensorflow-io-gcs-filesystem: 0.32.0, termcolor: 2.3.0, terminado: 0.17.1, tfrecord: 1.14.3, threadpoolctl: 3.1.0, timm: 0.6.12, tinycss2: 1.2.1, toml: 0.10.2, tomli: 2.0.1, tomli-w: 1.0.0, tomlkit: 0.10.2, toolz: 0.12.0, torch: 2.0.0, torchaudio: 2.0.1, torchmetrics: 0.11.4, torchvision: 0.15.1, tornado: 6.2, tqdm: 4.65.0, traitlets: 5.9.0, triton: 2.0.0, typeshed-client: 2.2.0, typing-extensions: 4.2.0, ujson: 5.2.0, uri-template: 1.2.0, urllib3: 1.26.15, uvicorn: 0.21.1, uvloop: 0.17.0, virtualenv: 20.14.1, wandb: 0.15.8, watchfiles: 0.19.0, wcwidth: 0.2.5, webcolors: 1.13, webencodings: 0.5.1, websocket-client: 1.5.1, websockets: 11.0.1, werkzeug: 2.3.3, wheel: 0.37.1, widgetsnbextension: 4.0.7, wrapt: 1.14.1, xarray: 2023.3.0, xgboost: 1.7.5, xlrd: 2.0.1, y-py: 0.5.9, yarl: 1.8.2, ypy-websocket: 0.8.2, zarr: 2.14.2, zict: 2.2.0, zipfile36: 0.1.3, zipp: 3.8.0
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.10.4
  - release: 4.18.0-425.19.2.el8_7.x86_64
  - version: #1 SMP Tue Apr 4 22:38:11 UTC 2023

More info

No response

cc @borda @carmocca @mauvilsa

awaelchli commented 1 year ago

@SebastianGer Thanks for reporting. Does the problem still persist when upgrading jsonargparse to the latest?

SebastianGer commented 1 year ago

> @SebastianGer Thanks for reporting. Does the problem still persist when upgrading jsonargparse to the latest?

Yes. After upgrading to jsonargparse==4.23.1 (latest version, according to docs), the results are still the same.

The same thing happens when I swap `bool` for `float`. I had trouble locating the responsible code in jsonargparse, but I assume they just iterate through possible types and pick the first one that fits, to resolve the type ambiguity. That seems like a reasonable design decision, just one that is easy to overlook.
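The first-fit resolution described here (and confirmed by @mauvilsa in this thread) can be modeled in a few lines. This is a simplified, hypothetical sketch, not jsonargparse's actual code: the converter functions and the accepted boolean spellings are assumptions. It shows why a permissive member such as `str`, placed first, shadows every later member, for `bool` and `float` alike.

```python
from typing import Callable, Sequence

# Assumed boolean spellings for this sketch; a real parser may accept others.
TRUE_WORDS = {"true", "yes"}
FALSE_WORDS = {"false", "no"}


def to_str(raw: str) -> str:
    return raw  # every command-line token is already a valid string


def to_bool(raw: str) -> bool:
    lowered = raw.lower()
    if lowered in TRUE_WORDS:
        return True
    if lowered in FALSE_WORDS:
        return False
    raise ValueError(f"not a bool: {raw!r}")


def to_float(raw: str) -> float:
    return float(raw)  # raises ValueError for non-numeric text


def to_literal_all(raw: str) -> str:
    # Stand-in for Literal["all"]: only the exact string "all" is accepted.
    if raw == "all":
        return raw
    raise ValueError(f"not 'all': {raw!r}")


def first_match(raw: str, members: Sequence[Callable[[str], object]]) -> object:
    """Try each Union member in declaration order; keep the first success."""
    for convert in members:
        try:
            return convert(raw)
        except ValueError:
            continue
    raise ValueError(f"no Union member accepts {raw!r}")


print(repr(first_match("True", [to_str, to_bool])))          # Union[str, bool] -> 'True'
print(repr(first_match("True", [to_bool, to_str])))          # Union[bool, str] -> True
print(repr(first_match("0.5", [to_str, to_float])))          # Union[str, float] -> '0.5'
print(repr(first_match("0.5", [to_float, to_str])))          # Union[float, str] -> 0.5
print(repr(first_match("True", [to_literal_all, to_bool])))  # Union[Literal["all"], bool] -> True
print(repr(first_match("all", [to_literal_all, to_bool])))   # Union[Literal["all"], bool] -> 'all'
```

Because `to_str` never fails, putting `str` first makes every later member unreachable. Narrowing it to `Literal["all"]` lets `"True"` fall through to `bool` while `"all"` still parses as a string, which matches the behavior reported for the proposed fix.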

awaelchli commented 1 year ago

@mauvilsa Do you agree with the proposal from @SebastianGer? Switching the order or using Literal?

mauvilsa commented 1 year ago

> I assume they just iterate through possible types and pick the first one that fits, to resolve the type ambiguity.

Correct.

> Do you agree with the proposal from @SebastianGer? Switching the order or using Literal?

If only "all" should be accepted as a string, then I would agree that the Literal is better.

awaelchli commented 1 year ago

@SebastianGer are you interested in sending a PR with the change?

SebastianGer commented 1 year ago

@awaelchli Sure. Will take a few days, but I'll try.