Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

lightning-cli does not support setting deterministic to ``"warn"`` #18184

Closed Galaxy-Husky closed 1 year ago

Galaxy-Husky commented 1 year ago

Bug description

The Trainer class documents that setting deterministic to "warn" uses deterministic algorithms whenever possible, throwing warnings on operations that don't support deterministic mode (requires PyTorch 1.11+). But I failed to set deterministic to "warn" in the config YAML for the LightningCLI.
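For context, when deterministic is set to "warn" the Trainer enables PyTorch's warn-only deterministic mode. Roughly, it amounts to the sketch below (an approximation of what Lightning does internally, not the exact implementation):

import torch

# Approximately what Trainer(deterministic="warn") turns on: use deterministic
# algorithms wherever they exist, and only warn (rather than raise) on ops that
# have no deterministic implementation. warn_only requires PyTorch 1.11+.
torch.use_deterministic_algorithms(True, warn_only=True)

The config that fails to parse is reproduced below.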

# lightning.pytorch==2.0.6 jsonargparse==4.23.0
seed_everything: 1234
trainer:
  accelerator: gpu
  strategy: auto
  devices: auto
  num_nodes: 1
  logger: true
  callbacks: null
  fast_dev_run: false
  max_epochs: 10
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: 10
  limit_val_batches: 2
  limit_test_batches: 2
  limit_predict_batches: 2
  overfit_batches: 0.0
  val_check_interval: null
  check_val_every_n_epoch: 1
  num_sanity_val_steps: 1
  log_every_n_steps: 1
  enable_checkpointing: false
  enable_progress_bar: true
  enable_model_summary: true
  accumulate_grad_batches: 1
  gradient_clip_val: 10
  gradient_clip_algorithm: value
  deterministic: warn
  benchmark: true
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: false
  reload_dataloaders_every_n_epochs: 0
  default_root_dir: null
model:
  hidden_dim: 64
  lr: 0.01
data:
  batch_size: 2
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.0001
    betas:
    - 0.9
    - 0.999
    eps: 1.0e-08
    weight_decay: 0.0
    amsgrad: false
    maximize: false
    foreach: null
    capturable: false
    differentiable: false
    fused: null
lr_scheduler:
  class_path: torch.optim.lr_scheduler.CosineAnnealingLR
  init_args:
    T_max: 200
    eta_min: 0.0
    last_epoch: -1
    verbose: false

What version are you seeing the problem on?

v2.0

How to reproduce the bug

from os import path
from typing import Optional, Tuple
import warnings

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split

from lightning.pytorch import callbacks, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.mnist_datamodule import MNIST
from lightning.pytorch.utilities import rank_zero_only
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.pytorch import Trainer
import logging

logging.getLogger("lightning.pytorch").setLevel(logging.INFO)
logger = logging.getLogger("lightning.pytorch.core")
logger.addHandler(logging.FileHandler("core.log"))

if _TORCHVISION_AVAILABLE:
    import torchvision
    from torchvision import transforms
    from torchvision.utils import save_image

warnings.filterwarnings('ignore')
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

class ImageSampler(callbacks.Callback):
    def __init__(
        self,
        num_samples: int = 3,
        nrow: int = 8,
        padding: int = 2,
        normalize: bool = True,
        norm_range: Optional[Tuple[int, int]] = None,
        scale_each: bool = False,
        pad_value: int = 0,
    ) -> None:

        if not _TORCHVISION_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")

        super().__init__()
        self.num_samples = num_samples
        self.nrow = nrow
        self.padding = padding
        self.normalize = normalize
        self.norm_range = norm_range
        self.scale_each = scale_each
        self.pad_value = pad_value

    def _to_grid(self, images):
        return torchvision.utils.make_grid(
            tensor=images,
            nrow=self.nrow,
            padding=self.padding,
            normalize=self.normalize,
            value_range=self.norm_range,
            scale_each=self.scale_each,
            pad_value=self.pad_value,
        )

    @rank_zero_only
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        if not _TORCHVISION_AVAILABLE:
            return

        images, _ = next(iter(DataLoader(trainer.datamodule.mnist_val, batch_size=self.num_samples)))
        images_flattened = images.view(images.size(0), -1)

        # generate images
        with torch.no_grad():
            pl_module.eval()
            images_generated = pl_module(images_flattened.to(pl_module.device))
            pl_module.train()

        if trainer.current_epoch == 0:
            save_image(self._to_grid(images), f"grid_ori_{trainer.current_epoch}.png")
        save_image(self._to_grid(images_generated.reshape(images.shape)), f"grid_generated_{trainer.current_epoch}.png")

class LitAutoEncoder(LightningModule):
    def __init__(self, hidden_dim: int = 64, lr=10e-3):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 3))
        self.decoder = nn.Sequential(nn.Linear(3, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28))

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        self._common_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x = self._prepare_batch(batch)
        return self(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def _prepare_batch(self, batch):
        x, _ = batch
        return x.view(x.size(0), -1)

    def _common_step(self, batch, batch_idx, stage: str):
        x = self._prepare_batch(batch)
        loss = F.mse_loss(x, self(x))
        self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True)
        return loss

class MyDataModule(LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
        self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
        self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

def cli_main():
    trainer_defaults = {"callbacks": ImageSampler()}
    logger.info('test')
    cli = LightningCLI(
        LitAutoEncoder,
        MyDataModule,
        run=False,
        trainer_defaults=trainer_defaults,
        save_config_kwargs={"overwrite": True},
    )
    print(cli.trainer.precision)
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    # cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
    # predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule)
    # print(predictions[0])

if __name__ == "__main__":
    # cli_lightning_logo()
    cli_main()
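Judging from the traceback and debug log further down, the script above is saved as autoencoder.py and run against the config from the bug description (config.yaml), presumably along the lines of:

python autoencoder.py --config config.yaml

The exact invocation is an assumption; only the autoencoder.py and config.yaml names appear in the logs below.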

Error messages and logs

error: Parser key "trainer.deterministic":
  Does not validate against any of the Union subtypes
  Subtypes: (<class 'bool'>, <class 'NoneType'>)
  Errors:
    - Expected a <class 'bool'>
    - Expected a <class 'NoneType'>
  Given value type: <class 'str'>
  Given value: warn

Environment

Current environment * CUDA: - GPU: - NVIDIA GeForce MX550 - available: True - version: 11.7 * Lightning: - lightning: 2.0.6 - lightning-cloud: 0.5.37 - lightning-utilities: 0.9.0 - pytorch-lightning: 2.0.4 - pytorch-optimizer: 2.10.1 - pytorch-pretrained-bert: 0.6.2 - pytorch-warmup: 0.1.1 - torch: 2.0.0 - torchaudio: 2.0.1+cu117 - torchdata: 0.6.0 - torchinfo: 1.8.0 - torchmetrics: 1.0.0 - torchtext: 0.15.1 - torchvision: 0.15.1 * Packages: - absl-py: 1.3.0 - aiohttp: 3.8.3 - aiosignal: 1.2.0 - alabaster: 0.7.13 - annotated-types: 0.5.0 - ansicon: 1.89.0 - antlr4-python3-runtime: 4.9.3 - anyascii: 0.3.2 - anyio: 3.6.2 - appdirs: 1.4.4 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - asttokens: 2.2.1 - async-timeout: 4.0.2 - attrs: 20.2.0 - autopep8: 1.6.0 - babel: 2.12.1 - backcall: 0.2.0 - backoff: 2.2.1 - backports.functools-lru-cache: 1.6.4 - beautifulsoup4: 4.11.2 - bleach: 6.0.0 - blessed: 1.20.0 - blinker: 1.4 - boto3: 1.26.94 - botocore: 1.29.94 - brotlipy: 0.7.0 - cachetools: 4.2.2 - certifi: 2022.12.7 - cffi: 1.15.1 - charset-normalizer: 2.1.1 - click: 8.0.4 - colorama: 0.4.6 - coloredlogs: 15.0.1 - comm: 0.1.2 - contourpy: 1.0.6 - contractions: 0.1.73 - croniter: 1.3.15 - cryptography: 39.0.0 - cycler: 0.11.0 - datasets: 2.2.1 - dateutils: 0.6.12 - debugpy: 1.6.5 - decorator: 5.1.1 - deepdiff: 6.3.1 - defusedxml: 0.7.1 - dill: 0.3.6 - docformatter: 1.5.1 - docstring-parser: 0.15 - docutils: 0.18.1 - emoji: 2.2.0 - entrypoints: 0.4 - exceptiongroup: 1.1.1 - executing: 1.2.0 - fairscale: 0.4.13 - fastapi: 0.100.0 - fastjsonschema: 2.16.2 - filelock: 3.10.0 - flake8: 6.0.0 - flake8-bugbear: 23.3.12 - flit-core: 3.8.0 - fonttools: 4.38.0 - fqdn: 1.5.1 - frozenlist: 1.3.3 - fsspec: 2022.5.0 - fuzzywuzzy: 0.18.0 - gitdb: 4.0.10 - gitdb2: 4.0.2 - gitpython: 3.1.31 - google-api-core: 2.11.0 - google-auth: 2.16.2 - google-auth-oauthlib: 0.4.4 - google-cloud-core: 2.3.2 - google-cloud-storage: 2.7.0 - google-crc32c: 1.5.0 - google-resumable-media: 2.4.1 - googleapis-common-protos: 1.58.0 - grpcio: 1.52.1 - h11: 0.14.0 - huggingface-hub: 0.13.2 - humanfriendly: 10.0 - hydra-core: 1.3.2 - idna: 3.4 - ijson: 3.2.0.post0 - imagesize: 1.4.1 - importlib-metadata: 6.0.0 - importlib-resources: 5.10.2 - iniconfig: 2.0.0 - inquirer: 3.1.3 - iopath: 0.1.10 - ipdb: 0.13.13 - ipykernel: 6.19.4 - ipython: 8.8.0 - ipython-genutils: 0.2.0 - ipywidgets: 8.0.4 - isoduration: 20.11.0 - itsdangerous: 2.1.2 - jedi: 0.18.2 - jinja2: 3.1.2 - jinxed: 1.2.0 - jmespath: 1.0.1 - joblib: 1.2.0 - jsonargparse: 4.23.0 - jsonlines: 3.1.0 - jsonpointer: 2.3 - jsonschema: 4.17.3 - jupyter: 1.0.0 - jupyter-client: 7.4.8 - jupyter-console: 6.4.4 - jupyter-contrib-core: 0.4.2 - jupyter-contrib-nbextensions: 0.7.0 - jupyter-core: 5.1.2 - jupyter-events: 0.6.3 - jupyter-highlight-selected-word: 0.2.0 - jupyter-nbextensions-configurator: 0.6.3 - jupyter-server: 2.2.0 - jupyter-server-terminals: 0.4.4 - jupyterlab-pygments: 0.2.2 - jupyterlab-widgets: 3.0.5 - kiwisolver: 1.4.4 - lightning: 2.0.6 - lightning-cloud: 0.5.37 - lightning-utilities: 0.9.0 - lxml: 4.9.2 - markdown: 3.3.2 - markdown-it-py: 3.0.0 - markupsafe: 2.1.2 - matplotlib: 3.6.2 - matplotlib-inline: 0.1.6 - mccabe: 0.7.0 - mdit-py-plugins: 0.4.0 - mdurl: 0.1.2 - mistune: 2.0.4 - mock: 5.0.1 - mpmath: 1.2.1 - multidict: 6.0.2 - multiprocess: 0.70.14 - munkres: 1.1.4 - myst-parser: 2.0.0 - nb-conda-kernels: 2.3.1 - nbclassic: 0.4.8 - nbclient: 0.7.2 - nbconvert: 7.2.9 - nbformat: 5.7.3 - nest-asyncio: 1.5.6 - networkx: 3.0 - ninja: 1.10.2.4 - 
nltk: 3.8.1 - notebook: 6.5.2 - notebook-shim: 0.2.2 - numpy: 1.24.2 - oauthlib: 3.2.1 - omegaconf: 2.3.0 - ordered-set: 4.1.0 - packaging: 22.0 - pandas: 1.5.2 - pandocfilters: 1.5.0 - parso: 0.8.3 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.4.0 - pip: 23.0.1 - pkgutil-resolve-name: 1.3.10 - platformdirs: 2.6.2 - pluggy: 1.0.0 - ply: 3.11 - pooch: 1.6.0 - portalocker: 2.7.0 - prometheus-client: 0.16.0 - prompt-toolkit: 3.0.36 - protobuf: 3.20.3 - psutil: 5.9.4 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py-gfm: 2.0.0 - py-rouge: 1.1 - pyahocorasick: 2.0.0 - pyarrow: 11.0.0 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pycodestyle: 2.10.0 - pycparser: 2.21 - pydantic: 2.0.2 - pydantic-core: 2.1.2 - pyflakes: 3.0.1 - pygments: 2.14.0 - pyjwt: 2.4.0 - pyopenssl: 23.0.0 - pyparsing: 3.0.9 - pyqt5: 5.15.7 - pyqt5-sip: 12.11.0 - pyreadline: 2.1 - pyreadline3: 3.4.1 - pyrsistent: 0.19.3 - pysocks: 1.7.1 - pytest: 7.2.2 - pytest-datadir: 1.4.1 - pytest-regressions: 2.4.2 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-json-logger: 2.0.4 - python-multipart: 0.0.6 - pytorch-lightning: 2.0.4 - pytorch-optimizer: 2.10.1 - pytorch-pretrained-bert: 0.6.2 - pytorch-warmup: 0.1.1 - pytz: 2022.7.1 - pywin32: 304 - pywinpty: 2.0.10 - pyyaml: 6.0 - pyzmq: 24.0.1 - qtconsole: 5.4.0 - qtpy: 2.3.0 - rake-nltk: 1.0.6 - readchar: 4.0.5 - regex: 2022.10.31 - requests: 2.28.1 - requests-mock: 1.10.0 - requests-oauthlib: 1.3.0 - responses: 0.18.0 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.4.2 - rsa: 4.7.2 - rwkv: 0.8.0 - s3transfer: 0.6.0 - scikit-learn: 1.2.2 - scipy: 1.10.0 - send2trash: 1.8.0 - sentencepiece: 0.1.98 - setuptools: 65.6.3 - sh: 2.0.2 - sip: 6.7.5 - six: 1.16.0 - smmap: 5.0.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - soupsieve: 2.3.2.post1 - sphinx: 6.2.1 - sphinx-autodoc-typehints: 1.10.3 - sphinx-rtd-theme: 1.2.2 - sphinxcontrib-applehelp: 1.0.4 - sphinxcontrib-devhelp: 1.0.2 - sphinxcontrib-htmlhelp: 2.0.1 - sphinxcontrib-jquery: 4.1 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 1.0.3 - sphinxcontrib-serializinghtml: 1.1.5 - stack-data: 0.6.2 - starlette: 0.27.0 - starsessions: 1.3.0 - subword-nmt: 0.3.8 - sympy: 1.11.1 - tensorboard: 2.12.0 - tensorboard-data-server: 0.7.0 - tensorboard-plugin-wit: 1.8.1 - tensorboardx: 2.6 - terminado: 0.15.0 - textsearch: 0.0.24 - threadpoolctl: 3.1.0 - tinycss2: 1.2.1 - tokenizers: 0.13.2 - toml: 0.10.2 - tomli: 2.0.1 - torch: 2.0.0 - torchaudio: 2.0.1+cu117 - torchdata: 0.6.0 - torchinfo: 1.8.0 - torchmetrics: 1.0.0 - torchtext: 0.15.1 - torchvision: 0.15.1 - tornado: 6.2 - tqdm: 4.62.3 - traitlets: 5.8.0 - transformers: 4.27.1 - typeshed-client: 2.3.0 - typing-extensions: 4.7.1 - ujson: 5.7.0 - unicodedata2: 15.0.0 - unidecode: 1.3.6 - untokenize: 0.1.1 - uri-template: 1.2.0 - urllib3: 1.26.13 - uvicorn: 0.22.0 - wcwidth: 0.2.5 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.5.0 - websockets: 11.0.3 - werkzeug: 2.2.2 - wheel: 0.38.4 - widgetsnbextension: 4.0.5 - win-inet-pton: 1.1.0 - xxhash: 3.2.0 - yapf: 0.32.0 - yarl: 1.8.1 - zipp: 3.12.0 * System: - OS: Windows - architecture: - 64bit - WindowsPE - processor: Intel64 Family 6 Model 154 Stepping 4, GenuineIntel - python: 3.10.0 - release: 10 - version: 10.0.22621

More info

No response

cc @carmocca @mauvilsa

awaelchli commented 1 year ago

I tried to reproduce this but without success. I copied your code sample (thanks a lot for providing it!!) and ran with Lightning 2.0.6, jsonargparse 4.23.0 and Python 3.10. The training starts and I can also set the argument via --trainer.deterministic=warn. @Galaxy-Husky Would you mind double checking that your Python is picking up the right environment / package versions?
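A quick way to confirm which packages the interpreter actually resolves (a generic sketch, not taken from the thread):

import sys

import jsonargparse
import lightning

# Print the interpreter plus the version and install location of the packages
# that are actually imported, to rule out a stale or mixed environment.
print(sys.executable)
print("lightning", lightning.__version__, lightning.__file__)
print("jsonargparse", jsonargparse.__version__, jsonargparse.__file__)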

Galaxy-Husky commented 1 year ago

@awaelchli Thank you for your efforts to reproduce this! I have checked the current environment and filled it in under "Environment" above. To debug, I added export JSONARGPARSE_DEBUG=true before running the command. Here are the error messages:

C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\torchaudio\backend\utils.py:74: UserWarning: No audio backend is available.
  warnings.warn("No audio backend is available.")
test
2023-07-29 02:21:57,787 - LightningArgumentParser - DEBUG - Skipping parameter "precision" from "lightning.Trainer.__init__" because of: Unsupported type hint typing.Union[typing.Literal[64, 32, 16], typing.Literal['transformer-engine', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'], typing.Literal['64', '32', '16', 'bf16']].
2023-07-29 02:21:57,792 - LightningArgumentParser - DEBUG - Discarding unsupported subtypes {typing.Literal['warn']} from typing.Union[bool, typing.Literal['warn'], NoneType]      
2023-07-29 02:21:57,807 - LightningArgumentParser - DEBUG - Loaded default values from parser
2023-07-29 02:21:58,614 - LightningArgumentParser - DEBUG - Skipping parameter "params" from "torch.optim.AdamW.__init__" because of: Parameter requested to be skipped.
2023-07-29 02:21:58,616 - LightningArgumentParser - DEBUG - Parsed object: Namespace(lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, amsgrad=False, maximize=False, foreach=None, capturable=False, differentiable=False, fused=None)
2023-07-29 02:21:58,622 - LightningArgumentParser - DEBUG - Skipping parameter "optimizer" from "torch.optim.lr_scheduler.CosineAnnealingLR.__init__" because of: Parameter requested to be skipped.
2023-07-29 02:21:58,624 - LightningArgumentParser - DEBUG - Parsed object: Namespace(T_max=200, eta_min=0.0, last_epoch=-1, verbose=False)
2023-07-29 02:21:58,629 - LightningArgumentParser - DEBUG - Parsed yaml string: # lightning.pytorch==2.0.6 jsonargparse==4.23.0
seed_everything: 1234
trainer:
  accelerator: gpu
  strategy: auto
  devices: auto
  num_nodes: 1
  logger: true
  callbacks: null
  fast_dev_run: false
  max_epochs: 10
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: 10
  limit_val_batches: 2
  limit_test_batches: 2
  limit_predict_batches: 2
  overfit_batches: 0.0
  val_check_interval: null
  check_val_every_n_epoch: 1
  num_sanity_val_steps: 1
  log_every_n_steps: 1
  enable_checkpointing: false
  enable_progress_bar: true
  enable_model_summary: true
  accumulate_grad_batches: 1
  gradient_clip_val: 10
  gradient_clip_algorithm: value
  deterministic: true
  benchmark: true
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: false
  reload_dataloaders_every_n_epochs: 0
  default_root_dir: null
model:
  hidden_dim: 64
  lr: 0.01
data:
  batch_size: 2
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.0001
    betas:
    - 0.9
    - 0.999
    eps: 1.0e-08
    weight_decay: 0.0
    amsgrad: false
    maximize: false
    foreach: null
    capturable: false
    differentiable: false
    fused: null
lr_scheduler:
  class_path: torch.optim.lr_scheduler.CosineAnnealingLR
  init_args:
    T_max: 200
    eta_min: 0.0
    last_epoch: -1
    verbose: false

2023-07-29 02:21:58,630 - LightningArgumentParser - DEBUG - Parsed configuration from path: config.yaml
2023-07-29 02:21:58,632 - LightningArgumentParser - ERROR - Parser key "trainer.deterministic":
  Does not validate against any of the Union subtypes
  Subtypes: (<class 'bool'>, <class 'NoneType'>)
  Errors:
    - Expected a <class 'bool'>
    - Expected a <class 'NoneType'>
  Given value type: <class 'str'>
  Given value: warn
2023-07-29 02:21:58,633 - LightningArgumentParser - DEBUG - Debug enabled, thus raising exception instead of exit.
Traceback (most recent call last):
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\_typehints.py", line 656, in adapt_typehints
    vals.append(adapt_typehints(val, subtypehint, **adapt_kwargs))
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\_typehints.py", line 615, in adapt_typehints
    raise_unexpected_value(f"Expected a {typehint}", val)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\_typehints.py", line 554, in raise_unexpected_value
    raise ValueError(message) from exception
ValueError: Expected a <class 'bool'>. Got value: warn

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\_typehints.py", line 479, in _check_type
    raise ex
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\_typehints.py", line 467, in _check_type
    val = adapt_typehints(val, self._typehint, **kwargs)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\_typehints.py", line 661, in adapt_typehints
    raise_union_unexpected_value(typehint, val, vals)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\_typehints.py", line 562, in raise_union_unexpected_value
    raise ValueError(
ValueError: Does not validate against any of the Union subtypes
Subtypes: (<class 'bool'>, <class 'NoneType'>)
Errors:
  - Expected a <class 'bool'>
  - Expected a <class 'NoneType'>
Given value type: <class 'str'>
Given value: warn

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\_core.py", line 385, in parse_args
    cfg, unk = self.parse_known_args(args=args, namespace=cfg)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\_core.py", line 256, in parse_known_args
    namespace, args = self._parse_known_args(args, namespace)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\argparse.py", line 2062, in _parse_known_args
    start_index = consume_optional(start_index)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\argparse.py", line 2002, in consume_optional
    take_action(action, args, option_string)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\argparse.py", line 1930, in take_action
    action(self, namespace, argument_values, option_string)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\_typehints.py", line 429, in __call__
    val = self._check_type(val, append=append, cfg=cfg)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\_typehints.py", line 489, in _check_type
    raise TypeError(f'Parser key "{self.dest}"{elem}:\n{error}') from ex
TypeError: Parser key "trainer.deterministic":
  Does not validate against any of the Union subtypes
  Subtypes: (<class 'bool'>, <class 'NoneType'>)
  Errors:
    - Expected a <class 'bool'>
    - Expected a <class 'NoneType'>
  Given value type: <class 'str'>
  Given value: warn

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\Ping\PycharmProjects\lightning-examples-pytorch-basics\autoencoder.py", line 164, in <module>
    cli_main()
  File "C:\Users\Ping\PycharmProjects\lightning-examples-pytorch-basics\autoencoder.py", line 148, in cli_main
    cli = LightningCLI(
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\lightning\pytorch\cli.py", line 369, in __init__
    self.parse_arguments(self.parser, args)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\lightning\pytorch\cli.py", line 518, in parse_arguments
    self.config = parser.parse_args(args)
  Errors:
    - Expected a <class 'bool'>
    - Expected a <class 'NoneType'>
  Given value type: <class 'str'>
  Given value: warn

I'm not sure if some of the package versions are wrong.
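As a side note on the debug flag: export JSONARGPARSE_DEBUG=true assumes a POSIX-style shell (e.g. Git Bash), while the environment above is Windows. A shell-agnostic alternative is to set the variable from inside the script before the CLI is constructed, as sketched below (this assumes jsonargparse reads the variable lazily rather than at import time):

import os

# Equivalent of `export JSONARGPARSE_DEBUG=true`, set from Python so it works
# regardless of the shell; must run before LightningCLI/the parser is created.
os.environ["JSONARGPARSE_DEBUG"] = "true"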

mauvilsa commented 1 year ago

From the debug logs it is clear that the problem is typing.Literal. It is the same as in #18183, so there is no need to keep both issues open.

The problem is identical to the one fixed by https://github.com/omni-us/jsonargparse/pull/328; however, that was only observed on Python 3.9. I extended the CI/CD testing to include Python 3.10 and Windows, but still, it did not fail in that case, see https://github.com/omni-us/jsonargparse/actions/runs/5710605690/job/15471015986?pr=338. Finally, I noticed in typing_extensions the comment that the Literal bug was fixed in 3.11.0, 3.10.1 and 3.9.8.

So it seems the issue is related to a bug in old CPython releases, 3.10.0 being one of them. I have changed the fix so that it hopefully covers all cases (https://github.com/omni-us/jsonargparse/pull/338). However, I have no way of testing this. @Galaxy-Husky could you please try it out? Just do the following and run your code again:

pip uninstall -y jsonargparse
pip install "jsonargparse @ https://github.com/omni-us/jsonargparse/zipball/improve-testing"
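For reference, the failure should be reproducible without Lightning at all, since only the type annotation matters. Below is a minimal sketch (built on jsonargparse's type-hint support, which LightningCLI uses; the annotation matches the one shown in the debug log above):

from typing import Literal, Optional, Union

from jsonargparse import ArgumentParser

parser = ArgumentParser()
# Same annotation the Trainer uses for `deterministic`:
# Union[bool, Literal["warn"], None]
parser.add_argument("--deterministic", type=Optional[Union[bool, Literal["warn"]]], default=None)

cfg = parser.parse_args(["--deterministic=warn"])
# Expected output: warn. On affected CPython releases (such as 3.10.0) the
# Literal["warn"] subtype is discarded and parsing fails as in this issue.
print(cfg.deterministic)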

Galaxy-Husky commented 1 year ago

@mauvilsa Thanks a lot!! The problem has been solved.