Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Argument linking fails when model and data are set via the command line instead of being passed to the CLI #16032

tobemo closed this issue 1 year ago

tobemo commented 1 year ago

Bug description

Argument linking works when the model and datamodule are passed to the CLI constructor, for example MyLightningCLI(MyModel, MyData).

Argument linking fails when the model and datamodule are set on the command line with flags such as --model, --data and/or --config. In that case the parser raises a ValueError saying Target key "model.foo" must be for an individual argument. See the code below.

How to reproduce the bug

from pytorch_lightning.cli import LightningCLI
from pytorch_lightning import LightningModule, LightningDataModule

activations = {
    'MaxAbsScaler': 'Sigmoid',
    'Normalizer': 'Sigmoid',
    'QuantileTransformer': 'Sigmoid',
    'RobustScaler': 'PReLU',
    'StandardScaler': 'PReLU',
    'PowerTransformer': 'TanhShrink',
}

class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments(
            "data.scaler", "model.scaler",
            compute_fn=lambda scaler: activations.get(scaler, 'linear'),
            apply_on='instantiate'
        )

class MyModel(LightningModule):
    def __init__(self, activation: str) -> None:
        super().__init__()
        self.save_hyperparameters()
        print(activation)

class MyData(LightningDataModule):
    def __init__(self, scaler: str) -> None:
        super().__init__()
        self.save_hyperparameters()
        print(scaler)

def main():
    # this works :)
    # cli = MyLightningCLI(MyModel, MyData)
    # > python .\debug.py fit --data.scaler MaxAbsScaler
    # > prints Sigmoid

    # this does not work :(
    cli = MyLightningCLI()
    # > python .\debug.py fit --data MyData --data.scaler MaxAbsScaler --model MyModel
    # > ValueError: Target key "model.activation" must be for an individual argument.

if __name__ == '__main__':
    main()

Error messages and logs

Traceback (most recent call last):
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\debug.py", line 50, in <module>
    main()
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\debug.py", line 45, in main
    cli = MyLightningCLI()
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\cli.py", line 343, in __init__
    self.setup_parser(run, main_kwargs, subparser_kwargs)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\cli.py", line 403, in setup_parser
    self._add_subcommands(self.parser, **subparser_kwargs)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\cli.py", line 480, in _add_subcommands
    subcommand_parser = self._prepare_subcommand_parser(trainer_class, subcommand, **subparser_kwargs)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\cli.py", line 486, in _prepare_subcommand_parser
    self._add_arguments(parser)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\cli.py", line 439, in _add_arguments
    self.add_arguments_to_parser(parser)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\debug.py", line 17, in add_arguments_to_parser
    parser.link_arguments(
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\jsonargparse\link_arguments.py", line 379, in link_arguments
    ActionLink(self, source, target, compute_fn, apply_on)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\jsonargparse\link_arguments.py", line 141, in __init__
    raise ValueError(f'Target key "{target}" must be for an individual argument.')
ValueError: Target key "model.scaler" must be for an individual argument.

Environment

Current environment

```
* CUDA:
    - GPU:
        - NVIDIA GeForce GTX 970
    - available: True
    - version: 11.7
* Lightning:
    - lightning-utilities: 0.3.0
    - pytorch-lightning: 1.8.3.post1
    - torch: 1.13.0+cu117
    - torchaudio: 0.13.0+cu117
    - torchmetrics: 0.11.0
    - torchvision: 0.14.0+cu117
* Packages:
    - absl-py: 1.3.0
    - aiohttp: 3.8.3
    - aiosignal: 1.3.1
    - alembic: 1.8.1
    - antlr4-python3-runtime: 4.9.3
    - asttokens: 2.1.0
    - async-timeout: 4.0.2
    - attrs: 22.1.0
    - autopage: 0.5.1
    - backcall: 0.2.0
    - cachetools: 5.2.0
    - certifi: 2022.9.24
    - charset-normalizer: 2.1.1
    - cliff: 4.1.0
    - cmaes: 0.9.0
    - cmd2: 2.4.2
    - colorama: 0.4.6
    - colorlog: 6.7.0
    - commonmark: 0.9.1
    - contourpy: 1.0.6
    - cycler: 0.11.0
    - debugpy: 1.6.3
    - decorator: 5.1.1
    - docstring-parser: 0.15
    - entrypoints: 0.4
    - executing: 1.2.0
    - fire: 0.4.0
    - fonttools: 4.38.0
    - frozenlist: 1.3.3
    - fsspec: 2022.11.0
    - google-auth: 2.15.0
    - google-auth-oauthlib: 0.4.6
    - greenlet: 2.0.1
    - grpcio: 1.51.1
    - hydra-core: 1.2.0
    - idna: 3.4
    - imageio: 2.22.4
    - importlib-metadata: 4.13.0
    - ipykernel: 6.17.1
    - ipython: 8.6.0
    - jedi: 0.18.1
    - joblib: 1.2.0
    - jsonargparse: 4.18.0
    - jupyter-client: 7.4.6
    - jupyter-core: 5.0.0
    - kiwisolver: 1.4.4
    - lightning-utilities: 0.3.0
    - mako: 1.2.4
    - markdown: 3.4.1
    - markupsafe: 2.1.1
    - matplotlib: 3.6.2
    - matplotlib-inline: 0.1.6
    - multidict: 6.0.2
    - nest-asyncio: 1.5.6
    - networkx: 2.8.8
    - numpy: 1.23.4
    - oauthlib: 3.2.2
    - omegaconf: 2.2.3
    - optuna: 3.0.4
    - packaging: 21.3
    - pandas: 1.5.1
    - parso: 0.8.3
    - pbr: 5.11.0
    - pickleshare: 0.7.5
    - pillow: 9.3.0
    - pip: 22.3.1
    - platformdirs: 2.5.4
    - prettytable: 3.5.0
    - prompt-toolkit: 3.0.32
    - protobuf: 3.20.1
    - psutil: 5.9.4
    - pure-eval: 0.2.2
    - pyarrow: 10.0.0
    - pyasn1: 0.4.8
    - pyasn1-modules: 0.2.8
    - pygments: 2.13.0
    - pyparsing: 3.0.9
    - pyperclip: 1.8.2
    - pyreadline3: 3.4.1
    - python-dateutil: 2.8.2
    - pytorch-lightning: 1.8.3.post1
    - pytz: 2022.6
    - pywavelets: 1.4.1
    - pywin32: 305
    - pyyaml: 6.0
    - pyzmq: 24.0.1
    - requests: 2.28.1
    - requests-oauthlib: 1.3.1
    - rich: 12.6.0
    - rsa: 4.9
    - scikit-image: 0.19.3
    - scikit-learn: 1.1.3
    - scipy: 1.8.1
    - setuptools: 58.1.0
    - six: 1.16.0
    - sqlalchemy: 1.4.44
    - stack-data: 0.6.1
    - stevedore: 4.1.1
    - tensorboard: 2.11.0
    - tensorboard-data-server: 0.6.1
    - tensorboard-plugin-wit: 1.8.1
    - tensorboardx: 2.5.1
    - termcolor: 2.1.1
    - threadpoolctl: 3.1.0
    - tifffile: 2022.10.10
    - torch: 1.13.0+cu117
    - torchaudio: 0.13.0+cu117
    - torchmetrics: 0.11.0
    - torchvision: 0.14.0+cu117
    - tornado: 6.2
    - tqdm: 4.64.1
    - traitlets: 5.5.0
    - typing-extensions: 4.4.0
    - ujson: 5.5.0
    - urllib3: 1.26.12
    - wcwidth: 0.2.5
    - werkzeug: 2.2.2
    - wheel: 0.38.4
    - yarl: 1.8.1
    - zipp: 3.11.0
* System:
    - OS: Windows
    - architecture:
        - 64bit
        - WindowsPE
    - processor: AMD64 Family 23 Model 8 Stepping 2, AuthenticAMD
    - python: 3.10.4
    - version: 10.0.19045
```

cc @carmocca @mauvilsa

samvanstroud commented 1 year ago

See also https://github.com/omni-us/jsonargparse/issues/208

mauvilsa commented 1 year ago

In pull request https://github.com/omni-us/jsonargparse/pull/218 this error message was changed to make the problem clearer. Note that the link_arguments call in the code above is a mistake and should give an error: MyModel has no scaler parameter, and when the model class is selected on the command line its parameters are nested under model.init_args. The correct link would be:

parser.link_arguments(
    "data.scaler", "model.init_args.activation",
    compute_fn=lambda scaler: activations.get(scaler, 'linear'),
    apply_on='instantiate'
)

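To illustrate why the target needs init_args, here is a minimal stdlib-only sketch. It assumes, as a simplification of how jsonargparse stores things, that a class chosen on the command line (--model MyModel) ends up as a nested class_path/init_args mapping, while a class fixed in the constructor has its parameters directly under model. The resolve_key helper and the class_path values are hypothetical and only walk dotted keys through nested dicts; they are not jsonargparse internals.

```python
from typing import Any, Mapping


def resolve_key(config: Mapping[str, Any], dotted_key: str) -> bool:
    """Hypothetical helper: True if every dot-separated part of dotted_key exists in config."""
    node: Any = config
    for part in dotted_key.split("."):
        if not isinstance(node, Mapping) or part not in node:
            return False
        node = node[part]
    return True


# Illustrative shape when the classes are fixed in the constructor:
# MyLightningCLI(MyModel, MyData)
fixed_class_config = {
    "model": {"activation": None},
    "data": {"scaler": "MaxAbsScaler"},
}

# Illustrative shape when the classes are chosen on the command line:
# --model MyModel --data MyData (the class_path values are assumptions)
subclass_config = {
    "model": {"class_path": "__main__.MyModel", "init_args": {"activation": None}},
    "data": {"class_path": "__main__.MyData", "init_args": {"scaler": "MaxAbsScaler"}},
}

for key in ("model.scaler", "model.activation", "model.init_args.activation"):
    print(
        f"{key}: fixed={resolve_key(fixed_class_config, key)}, "
        f"subclass={resolve_key(subclass_config, key)}"
    )
```

Running it prints that model.scaler resolves in neither shape, model.activation only in the fixed-class shape, and model.init_args.activation only in the command-line shape. Only the target side needs init_args: with apply_on='instantiate', the source data.scaler names an attribute of the already instantiated data module rather than a config key.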
tobemo commented 1 year ago

Great, this works! For completeness' sake: with apply_on='instantiate' the link is applied after the data module has been instantiated, so data.scaler is read from the instantiated object. At that point it is no longer a string but an object, and the lookup has to use its class name.

This works:

parser.link_arguments(
    "data.scaler", "model.init_args.activation",
    compute_fn=lambda scaler: activations.get(type(scaler).__name__, 'linear'),
    apply_on='instantiate'
)
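A minimal stdlib-only sketch of the difference between the two compute_fn versions, assuming the data module stores a scaler object whose class name matches the dictionary keys (the names suggest scikit-learn scalers, but the thread does not show that code). MaxAbsScaler, RobustScaler, CustomScaler and pick_activation below are hypothetical stand-ins, not the real classes.

```python
# Same mapping as in the bug report.
activations = {
    'MaxAbsScaler': 'Sigmoid',
    'Normalizer': 'Sigmoid',
    'QuantileTransformer': 'Sigmoid',
    'RobustScaler': 'PReLU',
    'StandardScaler': 'PReLU',
    'PowerTransformer': 'TanhShrink',
}


# Hypothetical stand-ins for the scaler objects the data module would hold;
# only their class names matter for the lookup.
class MaxAbsScaler:
    pass


class RobustScaler:
    pass


class CustomScaler:
    pass


def pick_activation(scaler: object) -> str:
    """Body of the working compute_fn: look up the scaler's class name."""
    return activations.get(type(scaler).__name__, 'linear')


scaler = MaxAbsScaler()
print(activations.get(scaler, 'linear'))  # linear: the object itself is never a key
print(pick_activation(scaler))            # Sigmoid
print(pick_activation(RobustScaler()))    # PReLU
print(pick_activation(CustomScaler()))    # linear (unknown class name)
```

The original lambda silently falls back to 'linear' for every object, which is why the mistake does not raise an error on its own.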