Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
Apache License 2.0

Argument linking fails when setting model and data via command line instead of passing it to the CLI. #16032

Closed tobemo closed 1 year ago

tobemo commented 1 year ago

Bug description

Argument linking works when the model and datamodule are passed to the CLI constructor, like so: MyLightningCLI(MyModel, MyData).

Argument linking fails when the model and datamodule are set via the command line using flags like --model, --data and/or --config: it raises a ValueError saying that the target key must be for an individual argument. See the code below.

How to reproduce the bug

from pytorch_lightning.cli import LightningCLI
from pytorch_lightning import LightningModule, LightningDataModule

activations = {
    'MaxAbsScaler': 'Sigmoid',
    'Normalizer': 'Sigmoid',
    'QuantileTransformer': 'Sigmoid',
    'RobustScaler': 'PReLU',
    'StandardScaler': 'PReLU',
    'PowerTransformer': 'TanhShrink',
}

class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # Derive a value for the model from the scaler chosen for the datamodule.
        parser.link_arguments(
            "data.scaler", "model.scaler",
            compute_fn=lambda scaler: activations.get(scaler, 'linear'),
        )

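class MyModel(LightningModule):
    def __init__(self, activation: str) -> None:
        super().__init__()
        print(activation)  # shows the value received through the link

class MyData(LightningDataModule):
    def __init__(self, scaler: str) -> None:
        super().__init__()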

def main():
    # this works :)
    # cli = MyLightningCLI(MyModel, MyData)
    # > python fit .\ --data.scaler MaxAbsScaler
    # > prints Sigmoid

    # this does not work :(
    cli = MyLightningCLI()
    # > python .\ fit --data MyData --data.scaler MaxAbsScaler --model MyModel
    # > ValueError: Target key "model.activation" must be for an individual argument.

if __name__ == '__main__':
    main()

Error messages and logs

Traceback (most recent call last):
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\", line 50, in <module>
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\", line 45, in main
    cli = MyLightningCLI()
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\", line 343, in __init__
    self.setup_parser(run, main_kwargs, subparser_kwargs)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\", line 403, in setup_parser
    self._add_subcommands(self.parser, **subparser_kwargs)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\", line 480, in _add_subcommands
    subcommand_parser = self._prepare_subcommand_parser(trainer_class, subcommand, **subparser_kwargs)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\", line 486, in _prepare_subcommand_parser
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\", line 439, in _add_arguments
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\", line 17, in add_arguments_to_parser
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\jsonargparse\", line 379, in link_arguments
    ActionLink(self, source, target, compute_fn, apply_on)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\jsonargparse\", line 141, in __init__
    raise ValueError(f'Target key "{target}" must be for an individual argument.')
ValueError: Target key "model.scaler" must be for an individual argument.


Current environment

```
* CUDA:
    - GPU:
        - NVIDIA GeForce GTX 970
    - available: True
    - version: 11.7
* Lightning:
    - lightning-utilities: 0.3.0
    - pytorch-lightning: 1.8.3.post1
    - torch: 1.13.0+cu117
    - torchaudio: 0.13.0+cu117
    - torchmetrics: 0.11.0
    - torchvision: 0.14.0+cu117
* Packages:
    - absl-py: 1.3.0
    - aiohttp: 3.8.3
    - aiosignal: 1.3.1
    - alembic: 1.8.1
    - antlr4-python3-runtime: 4.9.3
    - asttokens: 2.1.0
    - async-timeout: 4.0.2
    - attrs: 22.1.0
    - autopage: 0.5.1
    - backcall: 0.2.0
    - cachetools: 5.2.0
    - certifi: 2022.9.24
    - charset-normalizer: 2.1.1
    - cliff: 4.1.0
    - cmaes: 0.9.0
    - cmd2: 2.4.2
    - colorama: 0.4.6
    - colorlog: 6.7.0
    - commonmark: 0.9.1
    - contourpy: 1.0.6
    - cycler: 0.11.0
    - debugpy: 1.6.3
    - decorator: 5.1.1
    - docstring-parser: 0.15
    - entrypoints: 0.4
    - executing: 1.2.0
    - fire: 0.4.0
    - fonttools: 4.38.0
    - frozenlist: 1.3.3
    - fsspec: 2022.11.0
    - google-auth: 2.15.0
    - google-auth-oauthlib: 0.4.6
    - greenlet: 2.0.1
    - grpcio: 1.51.1
    - hydra-core: 1.2.0
    - idna: 3.4
    - imageio: 2.22.4
    - importlib-metadata: 4.13.0
    - ipykernel: 6.17.1
    - ipython: 8.6.0
    - jedi: 0.18.1
    - joblib: 1.2.0
    - jsonargparse: 4.18.0
    - jupyter-client: 7.4.6
    - jupyter-core: 5.0.0
    - kiwisolver: 1.4.4
    - lightning-utilities: 0.3.0
    - mako: 1.2.4
    - markdown: 3.4.1
    - markupsafe: 2.1.1
    - matplotlib: 3.6.2
    - matplotlib-inline: 0.1.6
    - multidict: 6.0.2
    - nest-asyncio: 1.5.6
    - networkx: 2.8.8
    - numpy: 1.23.4
    - oauthlib: 3.2.2
    - omegaconf: 2.2.3
    - optuna: 3.0.4
    - packaging: 21.3
    - pandas: 1.5.1
    - parso: 0.8.3
    - pbr: 5.11.0
    - pickleshare: 0.7.5
    - pillow: 9.3.0
    - pip: 22.3.1
    - platformdirs: 2.5.4
    - prettytable: 3.5.0
    - prompt-toolkit: 3.0.32
    - protobuf: 3.20.1
    - psutil: 5.9.4
    - pure-eval: 0.2.2
    - pyarrow: 10.0.0
    - pyasn1: 0.4.8
    - pyasn1-modules: 0.2.8
    - pygments: 2.13.0
    - pyparsing: 3.0.9
    - pyperclip: 1.8.2
    - pyreadline3: 3.4.1
    - python-dateutil: 2.8.2
    - pytorch-lightning: 1.8.3.post1
    - pytz: 2022.6
    - pywavelets: 1.4.1
    - pywin32: 305
    - pyyaml: 6.0
    - pyzmq: 24.0.1
    - requests: 2.28.1
    - requests-oauthlib: 1.3.1
    - rich: 12.6.0
    - rsa: 4.9
    - scikit-image: 0.19.3
    - scikit-learn: 1.1.3
    - scipy: 1.8.1
    - setuptools: 58.1.0
    - six: 1.16.0
    - sqlalchemy: 1.4.44
    - stack-data: 0.6.1
    - stevedore: 4.1.1
    - tensorboard: 2.11.0
    - tensorboard-data-server: 0.6.1
    - tensorboard-plugin-wit: 1.8.1
    - tensorboardx: 2.5.1
    - termcolor: 2.1.1
    - threadpoolctl: 3.1.0
    - tifffile: 2022.10.10
    - torch: 1.13.0+cu117
    - torchaudio: 0.13.0+cu117
    - torchmetrics: 0.11.0
    - torchvision: 0.14.0+cu117
    - tornado: 6.2
    - tqdm: 4.64.1
    - traitlets: 5.5.0
    - typing-extensions: 4.4.0
    - ujson: 5.5.0
    - urllib3: 1.26.12
    - wcwidth: 0.2.5
    - werkzeug: 2.2.2
    - wheel: 0.38.4
    - yarl: 1.8.1
    - zipp: 3.11.0
* System:
    - OS: Windows
    - architecture:
        - 64bit
        - WindowsPE
    - processor: AMD64 Family 23 Model 8 Stepping 2, AuthenticAMD
    - python: 3.10.4
    - version: 10.0.19045
```

More info

No response

cc @carmocca @mauvilsa

samvanstroud commented 1 year ago

See also

mauvilsa commented 1 year ago

In pull request this error message has been changed to make it more clear what the problem is. Note that the link_arguments in the code above is a mistake and should give an error. The correct link would be:

        parser.link_arguments(
            "data.scaler", "model.init_args.activation",
            compute_fn=lambda scaler: activations.get(scaler, 'linear'),
        )
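
To illustrate why the target needs the init_args component, here is a minimal sketch using plain dicts as a simplified, hypothetical stand-in for the parsed configuration (the real objects are jsonargparse namespaces, and the class_path value shown is an assumption). The has_key helper is not part of any library; it only checks whether a dotted key resolves in a nested dict.

```python
# Hypothetical, simplified view of the parsed config in each mode
# (plain dicts, not the actual jsonargparse Namespace objects).

# MyLightningCLI(MyModel, MyData): the class is fixed, so the init
# parameters sit directly under "model".
fixed_class_cfg = {"model": {"activation": "Sigmoid"}}

# MyLightningCLI() with --model MyModel: the class is chosen at run time,
# so its init parameters are nested under "init_args".
subclass_cfg = {
    "model": {
        "class_path": "__main__.MyModel",  # assumed value, for illustration
        "init_args": {"activation": "Sigmoid"},
    }
}


def has_key(cfg: dict, dotted_key: str) -> bool:
    """Return True if every component of a dotted key exists in a nested dict."""
    node = cfg
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return False
        node = node[part]
    return True


print(has_key(fixed_class_cfg, "model.activation"))         # True
print(has_key(subclass_cfg, "model.activation"))            # False
print(has_key(subclass_cfg, "model.init_args.activation"))  # True
```

Under this assumption, a link target without init_args does not name an individual argument when the model class is selected on the command line, which matches the error reported above.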

tobemo commented 1 year ago

Great, this works! For completeness' sake, I would like to mention that data.scaler has already been instantiated by the time it goes through the linking process, so compute_fn receives an object rather than a string.

This works:

        parser.link_arguments(
            "data.scaler", "model.init_args.activation",
            compute_fn=lambda scaler: activations.get(type(scaler).__name__, 'linear'),
        )
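
The lookup by class name can be tried in isolation, without Lightning. The sketch below is only an illustration: MaxAbsScaler here is a hypothetical empty stand-in for a real scaler class, and activation_for is a made-up helper name that accepts either a string or an instantiated object, so the same function could serve as compute_fn in both situations described in this thread.

```python
activations = {
    'MaxAbsScaler': 'Sigmoid',
    'RobustScaler': 'PReLU',
}


class MaxAbsScaler:
    """Hypothetical stand-in for a real scaler class (e.g. from scikit-learn)."""


def activation_for(scaler) -> str:
    # Accept either the class name as a string or an instantiated scaler object.
    name = scaler if isinstance(scaler, str) else type(scaler).__name__
    # Fall back to 'linear' for scalers without a dedicated activation.
    return activations.get(name, 'linear')


print(activation_for(MaxAbsScaler()))  # Sigmoid
print(activation_for('RobustScaler'))  # PReLU
print(activation_for(object()))        # linear
```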