Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

LightningCLI: KeyError for setting default argument and linking arguments in class_path and init_args #17346

Closed: janblumenkamp closed this 1 year ago

janblumenkamp commented 1 year ago

Bug description

I am attempting to set a default for an argument in the init_args of a class given by class_path. This results in a KeyError. Relatedly, it is not clear to me how I can link arguments into the trainer.callbacks field. Intuitively, I should be able to index the list, but this is not possible.

What version are you seeing the problem on?

2.0+

How to reproduce the bug

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.set_defaults({"trainer.logger.init_args.name": "Test"})  # Issue A
        parser.link_arguments(
            "trainer.logger.init_args.save_dir",
            "trainer.callbacks[1].init_args.save_dir",
        )  # Issue B

def main_cli():
    MyLightningCLI(DemoModel, BoringDataModule)

if __name__ == "__main__":
    main_cli()

Error messages and logs

Subissue A:

KeyError: 'No action for destination key "trainer.logger.init_args.name" to set its default.'

Subissue B:

ValueError: No action for key "trainer.callbacks[1].init_args.save_dir".

Environment

Current environment
* CUDA: GPU: None - available: False - version: None
* Lightning: lightning: 2.0.1 - lightning-api-access: 0.0.5 - lightning-cloud: 0.5.32 - lightning-fabric: 2.0.1 - lightning-utilities: 0.8.0 - pytorch-lightning: 1.9.0 - torch: 2.0.0 - torch-cluster: 1.6.1 - torch-geometric: 2.3.0 - torch-scatter: 2.1.1 - torch-sparse: 0.6.17 - torch-spline-conv: 1.2.2 - torchaudio: 0.13.1 - torchmetrics: 0.11.4 - torchvision: 0.15.1
* Packages: absl-py: 1.0.0 - affine: 2.4.0 - aiobotocore: 2.4.2 - aiohttp: 3.8.3 - aioitertools: 0.11.0 - aiosignal: 1.3.1 - altair: 4.2.2 - antlr4-python3-runtime: 4.9.3 - anyio: 3.6.2 - appdirs: 1.4.4 - argcomplete: 2.0.0 - argparse: 1.4.0 - arrow: 1.2.3 - astunparse: 1.6.3 - async-timeout: 4.0.2 - attrs: 22.1.0 - beautifulsoup4: 4.12.1 - black: 22.12.0 - bleach: 6.0.0 - blessed: 1.20.0 - blinker: 1.6 - bokeh: 2.4.3 - botocore: 1.27.59 - cachetools: 5.3.0 - catkin-pkg: 0.5.2 - certifi: 2022.12.7 - cffi: 1.15.1 - charset-normalizer: 2.1.1 - click: 8.1.3 - click-plugins: 1.1.1 - cligj: 0.7.2 - colcon-argcomplete: 0.3.3 - colcon-bash: 0.4.2 - colcon-cd: 0.1.1 - colcon-cmake: 0.2.26 - colcon-common-extensions: 0.2.1 - colcon-core: 0.10.0 - colcon-defaults: 0.2.6 - colcon-devtools: 0.2.3 - colcon-library-path: 0.2.1 - colcon-metadata: 0.2.5 - colcon-notification: 0.2.13 - colcon-output: 0.2.12 - colcon-package-information: 0.3.3 - colcon-package-selection: 0.2.10 - colcon-parallel-executor: 0.2.4 - colcon-pkg-config: 0.1.0 - colcon-powershell: 0.3.7 - colcon-python-setup-py: 0.2.7 - colcon-recursive-crawl: 0.2.1 - colcon-ros: 0.3.23 - colcon-test-result: 0.3.8 - colcon-zsh: 0.4.0 - contextlib2: 21.6.0 - coverage: 6.4.3 - croniter: 1.3.8 - cryptography: 37.0.4 - cycler: 0.11.0 - dateutils: 0.6.12 - decorator: 5.1.1 - deepdiff: 6.3.0 - distlib: 0.3.5 - distro: 1.7.0 - dm-tree: 0.1.8 - dnspython: 2.3.0 - docker: 6.0.1 - docker-pycreds: 0.4.0 - docstring-parser: 0.15 - docutils: 0.19 - email-validator: 1.3.1 - empy: 3.3.4 - entrypoints: 0.4 - etils: 1.1.1 - fastapi: 0.88.0 - filelock: 3.10.3 - flake8: 5.0.4 - flake8-blind-except: 0.2.1 - flake8-builtins: 1.5.3 - flake8-class-newline: 1.6.0 - flake8-comprehensions: 3.10.0 - flake8-deprecated: 1.3 - flake8-docstrings: 1.6.0 - flake8-import-order: 0.18.1 - flake8-quotes: 3.3.1 - flatbuffers: 23.3.3 - fonttools: 4.34.4 - frozenlist: 1.3.3 - fsspec: 2022.11.0 - gast: 0.4.0 - gitdb: 4.0.10 - gitpython: 3.1.31 - google-auth: 2.16.2 - google-auth-oauthlib: 0.4.6 - google-pasta: 0.2.0 - googleapis-common-protos: 1.59.0 - grpcio: 1.51.3 - h11: 0.14.0 - h5py: 3.8.0 - habitat-sim: 0.2.3 - httpcore: 0.16.3 - httptools: 0.5.0 - httpx: 0.23.3 - hydra-core: 1.3.2 - idna: 3.4 - ifcfg: 0.23 - imageio: 2.26.0 - imageio-ffmpeg: 0.4.8 - importlib-metadata: 4.12.0 - importlib-resources: 5.12.0 - iniconfig: 1.1.1 - inquirer: 3.1.3 - itsdangerous: 2.1.2 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.2.0 - jsonargparse: 4.20.1 - jsonschema: 4.17.3 - keras: 2.11.0 - kiwisolver: 1.4.4 - lark: 1.1.2 - lark-parser: 0.12.0 - libclang: 15.0.6.1 - lightning: 2.0.1 - lightning-api-access: 0.0.5 - lightning-cloud: 0.5.32 - lightning-fabric: 2.0.1 - lightning-utilities: 0.8.0 - llvmlite: 0.39.1 - lxml: 4.9.1 - magnum: 0.0.0 - markdown: 3.4.2 - markdown-it-py: 2.2.0 - markupsafe: 2.1.2 - matplotlib: 3.5.3 - mccabe: 0.7.0 - mdurl: 0.1.2 - ml-collections: 0.1.0 - mock: 4.0.3 - mpmath: 1.3.0 - multidict: 6.0.4 - mypy: 0.971 - mypy-extensions: 0.4.3 - netifaces: 0.11.0 - networkx: 3.0 - nose: 1.3.7 - numba: 0.56.4 - numpy: 1.24.2 - numpy-quaternion: 2022.4.3 - oauthlib: 3.2.2 - omegaconf: 2.3.0 - opencv-python: 4.7.0.68 - opt-einsum: 3.3.0 - ordered-set: 4.1.0 - orjson: 3.8.9 - packaging: 21.3 - pandas: 1.5.3 - panel: 0.14.4 - param: 1.13.0 - pathspec: 0.10.3 - pathtools: 0.1.2 - pep8: 1.7.1 - pillow: 9.2.0 - pip: 23.0 - platformdirs: 2.6.0 - pluggy: 1.0.0 - ply: 3.11 - promise: 2.3 - protobuf: 3.19.6 - psutil: 5.9.1 - py: 1.11.0 - pyarrow: 11.0.0 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pycodestyle: 2.9.1 - pycparser: 2.21 - pyct: 0.5.0 - pydantic: 1.10.7 - pydeck: 0.8.0 - pydocstyle: 6.1.1 - pydot: 1.4.2 - pyflakes: 2.5.0 - pygments: 2.14.0 - pygraphviz: 1.9 - pyjwt: 2.6.0 - pympler: 1.0.1 - pyparsing: 3.0.9 - pyqt3d: 5.15.5 - pyqt5: 5.15.7 - pyqt5-sip: 12.11.0 - pyqtchart: 5.15.6 - pyqtdatavisualization: 5.15.5 - pyqtnetworkauth: 5.15.5 - pyqtpurchasing: 5.15.5 - pyqtwebengine: 5.15.6 - pyrsistent: 0.19.3 - pytest: 7.1.2 - pytest-cov: 3.0.0 - pytest-mock: 3.8.2 - pytest-repeat: 0.9.1 - pytest-rerunfailures: 10.2 - python-dateutil: 2.8.2 - python-dotenv: 1.0.0 - python-editor: 1.0.4 - python-multipart: 0.0.6 - pytorch-lightning: 1.9.0 - pytz: 2022.7 - pytz-deprecation-shim: 0.1.0.post0 - pyvisgraph: 0.2.1 - pyviz-comms: 2.2.1 - pyyaml: 6.0 - pyzmq: 24.0.1 - rasterio: 1.3.6 - readchar: 4.0.5 - redis: 4.5.4 - requests: 2.28.2 - requests-oauthlib: 1.3.1 - rfc3986: 1.5.0 - rich: 13.3.3 - rosdep: 0.22.1 - rosdistro: 0.9.0 - rosinstall-generator: 0.1.22 - rospkg: 1.4.0 - rsa: 4.9 - s3fs: 2022.11.0 - scikit-learn: 1.2.1 - scipy: 1.10.0 - seaborn: 0.12.2 - semver: 3.0.0 - sentry-sdk: 1.17.0 - setproctitle: 1.3.2 - setuptools: 63.4.3 - shapely: 1.8.2 - sip: 6.6.2 - six: 1.16.0 - smmap: 5.0.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - snuggs: 1.4.7 - soupsieve: 2.4 - starlette: 0.22.0 - starsessions: 1.3.0 - streamlit: 1.20.0 - sympy: 1.11.1 - tensorboard: 2.11.2 - tensorboard-data-server: 0.6.1 - tensorboard-plugin-wit: 1.8.1 - tensorboardx: 2.6 - tensorflow-datasets: 4.8.3 - tensorflow-estimator: 2.11.0 - tensorflow-macos: 2.11.0 - tensorflow-metadata: 1.12.0 - termcolor: 2.2.0 - threadpoolctl: 3.1.0 - toml: 0.10.2 - tomli: 2.0.1 - toolz: 0.12.0 - torch: 2.0.0 - torch-cluster: 1.6.1 - torch-geometric: 2.3.0 - torch-scatter: 2.1.1 - torch-sparse: 0.6.17 - torch-spline-conv: 1.2.2 - torchaudio: 0.13.1 - torchmetrics: 0.11.4 - torchvision: 0.15.1 - tornado: 6.2 - tqdm: 4.64.1 - traitlets: 5.9.0 - typeshed-client: 2.2.0 - typing-extensions: 4.3.0 - tzdata: 2023.3 - tzlocal: 4.3 - ujson: 5.7.0 - urllib3: 1.26.14 - uvicorn: 0.21.1 - uvloop: 0.17.0 - validators: 0.20.0 - vcstool: 0.3.0 - wandb: 0.13.11 - watchfiles: 0.19.0 - wcwidth: 0.2.6 - webencodings: 0.5.1 - websocket-client: 1.5.1 - websockets: 11.0 - werkzeug: 2.2.3 - wheel: 0.37.1 - wrapt: 1.15.0 - yarl: 1.8.2 - zipp: 3.8.1 - zmcat: 0.0.11 - zmq: 0.0.0
* System: OS: Darwin - architecture: 64bit - processor: arm - python: 3.10.6 - version: Darwin Kernel Version 21.6.0: Sat Jun 18 17:07:28 PDT 2022; root:xnu-8020.140.41~1/RELEASE_ARM64_T8110

More info

Config file:

# lightning.pytorch==2.0.1
trainer:
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      save_dir: ./abcd/TE5T
      project: abcd
      name: TE5T
  callbacks:
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: 2
        dirpath: ./abcd/TE5T
        monitor: val_auroc
        save_last: True
  log_every_n_steps: 10

Run with

python3 test_cli.py fit --config configs/logging.yaml --print_config

cc @carmocca @mauvilsa

mauvilsa commented 1 year ago

@janblumenkamp unfortunately, these are not bugs. You are trying to do things which were never intended to work.

Regarding set_defaults, it is not supposed to work for an individual parameter of a subclass. What you can do is set the default for the entire subclass, specifying both the class_path and any init_args you want to differ from that class's defaults. That is:

parser.set_defaults({"trainer.logger": {"class_path": ..., "init_args": {...}}})
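A minimal sketch of that suggestion for the logger in this issue, assuming the WandbLogger class_path from the config file and the name "Test" from the reproduction script. The helpers `logger_subclass_default` and `add_logger_default` are hypothetical names; `parser` is assumed to be any object with a jsonargparse-style `set_defaults` method that accepts a dict, as in the one-liner above. The code itself only builds and passes the dict, so it runs without Lightning installed.

```python
from typing import Any, Dict


def logger_subclass_default(class_path: str, **init_args: Any) -> Dict[str, Any]:
    """Build a default for the *whole* trainer.logger subclass (hypothetical helper).

    Only the init_args that should differ from the class's own defaults are
    listed; every other parameter keeps the class default.
    """
    return {"trainer.logger": {"class_path": class_path, "init_args": dict(init_args)}}


def add_logger_default(parser: Any) -> None:
    # Hypothetical helper meant to be called from LightningCLI.add_arguments_to_parser.
    # Instead of parser.set_defaults({"trainer.logger.init_args.name": "Test"}),
    # which raises KeyError, give the default for the entire logger subclass.
    parser.set_defaults(
        logger_subclass_default("lightning.pytorch.loggers.WandbLogger", name="Test")
    )
```

In a `LightningCLI` subclass, `add_logger_default(parser)` would take the place of the `set_defaults` call marked `# Issue A`. As with any parser default, values given in a config file or on the command line still take precedence.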

Regarding link_arguments, using an item in a list as the target is not supported. Moreover, the example in the description does not make much sense. link_arguments is intended for hard-coded (non-configurable) behaviors, whereas callbacks is a list that the CLI's user can override. There is no guarantee that the n-th element of callbacks will always be a class that accepts a save_dir parameter. It would make more sense if it were a forced callback, which is guaranteed to always exist.
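To make the point about list positions concrete, here is a small stdlib-only simulation (not jsonargparse or Lightning code) of what an index-based target such as `trainer.callbacks[1]` would resolve to. The callback lists are plain dicts mirroring the config file from the issue; `class_at_index` and the overridden list are hypothetical, for illustration only.

```python
from typing import Any, Dict, List, Optional

Callback = Dict[str, Any]


def class_at_index(callbacks: List[Callback], index: int) -> Optional[str]:
    """Return the class_path an index-based link target would hit, or None."""
    if 0 <= index < len(callbacks):
        return callbacks[index]["class_path"]
    return None  # the link target would not exist at all


# Callbacks as given in the config file from the issue.
from_config: List[Callback] = [
    {"class_path": "lightning.pytorch.callbacks.LearningRateMonitor"},
    {
        "class_path": "lightning.pytorch.callbacks.ModelCheckpoint",
        "init_args": {"dirpath": "./abcd/TE5T"},
    },
]

# Hypothetical override by a CLI user who keeps only the checkpoint callback.
overridden: List[Callback] = [from_config[1]]

print(class_at_index(from_config, 1))  # lightning.pytorch.callbacks.ModelCheckpoint
print(class_at_index(overridden, 1))   # None
```

The same index points at a different class, or at nothing, depending on what the user passes. Even with the original list, index 1 is ModelCheckpoint, whose directory parameter in the config file is `dirpath`, not `save_dir`. A forced callback avoids this because it lives under a fixed key; in LightningCLI it is registered with `parser.add_lightning_class_args(ModelCheckpoint, "my_model_checkpoint")` (the key name here is illustrative), after which a link target would be a stable key such as `my_model_checkpoint.dirpath`.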

janblumenkamp commented 1 year ago

Understood, that makes sense! Thanks a lot for clarifying!