lightning-cli does not support setting deterministic to ``"warn"`` #18184

Closed Galaxy-Husky closed 1 year ago

Galaxy-Husky commented 1 year ago

Bug description

The Trainer docstring says that setting deterministic to "warn" uses deterministic algorithms whenever possible and throws warnings on operations that don't support deterministic mode (requires PyTorch 1.11+). However, I could not set deterministic to "warn" in the config YAML for lightning-cli.

# lightning.pytorch==2.0.6 jsonargparse==4.23.0
seed_everything: 1234
trainer:
  accelerator: gpu
  strategy: auto
  devices: auto
  num_nodes: 1
  logger: true
  callbacks: null
  fast_dev_run: false
  max_epochs: 10
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: 10
  limit_val_batches: 2
  limit_test_batches: 2
  limit_predict_batches: 2
  overfit_batches: 0.0
  val_check_interval: null
  check_val_every_n_epoch: 1
  num_sanity_val_steps: 1
  log_every_n_steps: 1
  enable_checkpointing: false
  enable_progress_bar: true
  enable_model_summary: true
  accumulate_grad_batches: 1
  gradient_clip_val: 10
  gradient_clip_algorithm: value
  deterministic: warn
  benchmark: true
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: false
  reload_dataloaders_every_n_epochs: 0
  default_root_dir: null
model:
  hidden_dim: 64
  lr: 0.01
data:
  batch_size: 2
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.0001
    betas:
    - 0.9
    - 0.999
    eps: 1.0e-08
    weight_decay: 0.0
    amsgrad: false
    maximize: false
    foreach: null
    capturable: false
    differentiable: false
    fused: null
lr_scheduler:
  class_path: torch.optim.lr_scheduler.CosineAnnealingLR
  init_args:
    T_max: 200
    eta_min: 0.0
    last_epoch: -1
    verbose: false

What version are you seeing the problem on?


How to reproduce the bug

from os import path
from typing import Optional, Tuple
import warnings

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split

from lightning.pytorch import callbacks, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.mnist_datamodule import MNIST
from lightning.pytorch.utilities import rank_zero_only
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.pytorch import Trainer
import logging

logger = logging.getLogger("lightning.pytorch.core")

if _TORCHVISION_AVAILABLE:
    import torchvision
    from torchvision import transforms
    from torchvision.utils import save_image

DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

class ImageSampler(callbacks.Callback):
    def __init__(
        self,
        num_samples: int = 3,
        nrow: int = 8,
        padding: int = 2,
        normalize: bool = True,
        norm_range: Optional[Tuple[int, int]] = None,
        scale_each: bool = False,
        pad_value: int = 0,
    ) -> None:
        super().__init__()
        if not _TORCHVISION_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")

        self.num_samples = num_samples
        self.nrow = nrow
        self.padding = padding
        self.normalize = normalize
        self.norm_range = norm_range
        self.scale_each = scale_each
        self.pad_value = pad_value

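    def _to_grid(self, images):
        # The call was cut off in the post; the arguments below are reconstructed
        # from the attributes stored in __init__.
        return torchvision.utils.make_grid(
            tensor=images,
            nrow=self.nrow,
            padding=self.padding,
            normalize=self.normalize,
            value_range=self.norm_range,
            scale_each=self.scale_each,
            pad_value=self.pad_value,
        )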

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:

        images, _ = next(iter(DataLoader(trainer.datamodule.mnist_val, batch_size=self.num_samples)))
        images_flattened = images.view(images.size(0), -1)

        # generate images
        with torch.no_grad():
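            # The argument was cut off in the post; reconstructed from the line above.
            images_generated = pl_module(images_flattened.to(pl_module.device))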

        if trainer.current_epoch == 0:
            save_image(self._to_grid(images), f"grid_ori_{trainer.current_epoch}.png")
        save_image(self._to_grid(images_generated.reshape(images.shape)), f"grid_generated_{trainer.current_epoch}.png")

class LitAutoEncoder(LightningModule):
    def __init__(self, hidden_dim: int = 64, lr=10e-3):
        super().__init__()
        # Stores hidden_dim and lr under self.hparams (assumed; lines were lost in the post).
        self.save_hyperparameters()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 3))
        self.decoder = nn.Sequential(nn.Linear(3, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28))

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        self._common_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x = self._prepare_batch(batch)
        return self(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def _prepare_batch(self, batch):
        x, _ = batch
        return x.view(x.size(0), -1)

    def _common_step(self, batch, batch_idx, stage: str):
        x = self._prepare_batch(batch)
        loss = F.mse_loss(x, self(x))
        self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True)
        return loss

class MyDataModule(LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
        self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
        self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

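def cli_main():
    # The dict was cut off in the post; only the "callbacks" entry survives.
    trainer_defaults = {"callbacks": ImageSampler()}
    # The positional arguments and run=False are assumptions inferred from the
    # config keys and the manual fit call below; they were lost from the post.
    cli = LightningCLI(
        LitAutoEncoder,
        MyDataModule,
        run=False,
        trainer_defaults=trainer_defaults,
        save_config_kwargs={"overwrite": True},
    )
    print(cli.trainer.precision)
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    # cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
    # predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule)
    # print(predictions[0])

if __name__ == "__main__":
    # cli_lightning_logo()
    cli_main()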

Error messages and logs

error: Parser key "trainer.deterministic":
  Does not validate against any of the Union subtypes
  Subtypes: (<class 'bool'>, <class 'NoneType'>)
    - Expected a <class 'bool'>
    - Expected a <class 'NoneType'>
  Given value type: <class 'str'>
  Given value: warn


More info

No response

cc @carmocca @mauvilsa

awaelchli commented 1 year ago

I tried to reproduce this but without success. I copied your code sample (thanks a lot for providing it!!) and ran with Lightning 2.0.6, jsonargparse 4.23.0 and Python 3.10. The training starts and I can also set the argument via --trainer.deterministic=warn. @Galaxy-Husky Would you mind double checking that your Python is picking up the right environment / package versions?
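
One way to double check which versions the active interpreter actually picks up is to query the installed distribution metadata. This is a minimal sketch using only the standard library; the helper name installed_versions is made up for illustration, and the default distribution names ("lightning", "jsonargparse", "torch") are assumed to match how the packages were installed.

```python
import platform
from importlib import metadata


def installed_versions(packages=("lightning", "jsonargparse", "torch")):
    """Return a dict mapping each distribution name to its version, or None if missing."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version}")
```

Running this with the same interpreter that launches the training script shows the full Python patch version, which matters for the diagnosis later in this thread.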

Galaxy-Husky commented 1 year ago

@awaelchli Thank you for your efforts to reproduce this! I have checked the current environment and filled it out in the "Environment" section above. To debug, I added export JSONARGPARSE_DEBUG=true before running the command. Here are the error messages:

C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\torchaudio\backend\ UserWarning: No audio backend is available.
  warnings.warn("No audio backend is available.")
2023-07-29 02:21:57,787 - LightningArgumentParser - DEBUG - Skipping parameter "precision" from "lightning.Trainer.__init__" because of: Unsupported type hint typing.Union[typing.Literal[64, 32, 16], typing.Literal['transformer-engine', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'], typing.Literal['64', '32', '16', 'bf16']].
2023-07-29 02:21:57,792 - LightningArgumentParser - DEBUG - Discarding unsupported subtypes {typing.Literal['warn']} from typing.Union[bool, typing.Literal['warn'], NoneType]      
2023-07-29 02:21:57,807 - LightningArgumentParser - DEBUG - Loaded default values from parser
2023-07-29 02:21:58,614 - LightningArgumentParser - DEBUG - Skipping parameter "params" from "torch.optim.AdamW.__init__" because of: Parameter requested to be skipped.
2023-07-29 02:21:58,616 - LightningArgumentParser - DEBUG - Parsed object: Namespace(lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, amsgrad=False, maximize=False, foreach=None, capturable=False, differentiable=False, fused=None)
2023-07-29 02:21:58,622 - LightningArgumentParser - DEBUG - Skipping parameter "optimizer" from "torch.optim.lr_scheduler.CosineAnnealingLR.__init__" because of: Parameter requested to be skipped.
2023-07-29 02:21:58,624 - LightningArgumentParser - DEBUG - Parsed object: Namespace(T_max=200, eta_min=0.0, last_epoch=-1, verbose=False)
2023-07-29 02:21:58,629 - LightningArgumentParser - DEBUG - Parsed yaml string: # lightning.pytorch==2.0.6 jsonargparse==4.23.0
seed_everything: 1234
trainer:
  accelerator: gpu
  strategy: auto
  devices: auto
  num_nodes: 1
  logger: true
  callbacks: null
  fast_dev_run: false
  max_epochs: 10
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: 10
  limit_val_batches: 2
  limit_test_batches: 2
  limit_predict_batches: 2
  overfit_batches: 0.0
  val_check_interval: null
  check_val_every_n_epoch: 1
  num_sanity_val_steps: 1
  log_every_n_steps: 1
  enable_checkpointing: false
  enable_progress_bar: true
  enable_model_summary: true
  accumulate_grad_batches: 1
  gradient_clip_val: 10
  gradient_clip_algorithm: value
  deterministic: true
  benchmark: true
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: false
  reload_dataloaders_every_n_epochs: 0
  default_root_dir: null
model:
  hidden_dim: 64
  lr: 0.01
data:
  batch_size: 2
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.0001
    betas:
    - 0.9
    - 0.999
    eps: 1.0e-08
    weight_decay: 0.0
    amsgrad: false
    maximize: false
    foreach: null
    capturable: false
    differentiable: false
    fused: null
lr_scheduler:
  class_path: torch.optim.lr_scheduler.CosineAnnealingLR
  init_args:
    T_max: 200
    eta_min: 0.0
    last_epoch: -1
    verbose: false

2023-07-29 02:21:58,630 - LightningArgumentParser - DEBUG - Parsed configuration from path: config.yaml
2023-07-29 02:21:58,632 - LightningArgumentParser - ERROR - Parser key "trainer.deterministic":
  Does not validate against any of the Union subtypes
  Subtypes: (<class 'bool'>, <class 'NoneType'>)
    - Expected a <class 'bool'>
    - Expected a <class 'NoneType'>
  Given value type: <class 'str'>
  Given value: warn
2023-07-29 02:21:58,633 - LightningArgumentParser - DEBUG - Debug enabled, thus raising exception instead of exit.
Traceback (most recent call last):
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\", line 656, in adapt_typehints
    vals.append(adapt_typehints(val, subtypehint, **adapt_kwargs))
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\", line 615, in adapt_typehints
    raise_unexpected_value(f"Expected a {typehint}", val)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\", line 554, in raise_unexpected_value
    raise ValueError(message) from exception
ValueError: Expected a <class 'bool'>. Got value: warn

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\", line 479, in _check_type
    raise ex
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\", line 467, in _check_type
    val = adapt_typehints(val, self._typehint, **kwargs)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\", line 661, in adapt_typehints
    raise_union_unexpected_value(typehint, val, vals)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\", line 562, in raise_union_unexpected_value
    raise ValueError(
ValueError: Does not validate against any of the Union subtypes
Subtypes: (<class 'bool'>, <class 'NoneType'>)
  - Expected a <class 'bool'>
  - Expected a <class 'NoneType'>
Given value type: <class 'str'>
Given value: warn

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\", line 385, in parse_args
    cfg, unk = self.parse_known_args(args=args, namespace=cfg)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\", line 256, in parse_known_args
    namespace, args = self._parse_known_args(args, namespace)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\", line 2062, in _parse_known_args
    start_index = consume_optional(start_index)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\", line 2002, in consume_optional
    take_action(action, args, option_string)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\", line 1930, in take_action
    action(self, namespace, argument_values, option_string)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\", line 429, in __call__
    val = self._check_type(val, append=append, cfg=cfg)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\jsonargparse\", line 489, in _check_type
    raise TypeError(f'Parser key "{self.dest}"{elem}:\n{error}') from ex
TypeError: Parser key "trainer.deterministic":
  Does not validate against any of the Union subtypes
  Subtypes: (<class 'bool'>, <class 'NoneType'>)
    - Expected a <class 'bool'>
    - Expected a <class 'NoneType'>
  Given value type: <class 'str'>
  Given value: warn

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\Ping\PycharmProjects\lightning-examples-pytorch-basics\", line 164, in <module>
  File "C:\Users\Ping\PycharmProjects\lightning-examples-pytorch-basics\", line 148, in cli_main
    cli = LightningCLI(
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\lightning\pytorch\", line 369, in __init__
    self.parse_arguments(self.parser, args)
  File "C:\Users\Ping\mambaforge\envs\pytorch\lib\site-packages\lightning\pytorch\", line 518, in parse_arguments
    self.config = parser.parse_args(args)
    - Expected a <class 'bool'>
    - Expected a <class 'NoneType'>
  Given value type: <class 'str'>
  Given value: warn

I'm not sure if some of the package versions are wrong.

mauvilsa commented 1 year ago

From the debug logs it is clear that the problem is typing.Literal. It is the same as in #18183, so there is no need to keep both issues open.
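
To illustrate why discarding the Literal subtype produces exactly this error, here is a minimal sketch of union validation. It is not jsonargparse's implementation; the validate function and its discard_literals flag are hypothetical and only mimic the behavior visible in the debug log ("Discarding unsupported subtypes {typing.Literal['warn']}").

```python
from typing import Literal, Union, get_args, get_origin

# Same shape as the type hint shown in the debug log.
DeterministicType = Union[bool, Literal["warn"], None]


def validate(value, typehint, discard_literals=False):
    """Return value if it matches one member of a Union type hint, else raise."""
    for subtype in get_args(typehint):
        if get_origin(subtype) is Literal:
            if discard_literals:
                continue  # mimic a parser that dropped the Literal subtype
            if value in get_args(subtype):
                return value
        elif isinstance(value, subtype):
            return value
    raise ValueError(f"Does not validate against any of the Union subtypes: {value!r}")


if __name__ == "__main__":
    print(validate("warn", DeterministicType))  # accepted when Literal is kept
    try:
        validate("warn", DeterministicType, discard_literals=True)
    except ValueError as err:
        print(err)  # only bool and NoneType remain, so "warn" is rejected
```

With the Literal member dropped, only bool and NoneType are left to check against, which matches the "Subtypes" line in the reported error.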

The problem is identical to one that was fixed previously; however, that was only observed on Python 3.9. I extended the CI/CD testing to include Python 3.10 and Windows, but it still did not fail in that case. Finally, I noticed in typing_extensions the comment "Literal bug was fixed in 3.11.0, 3.10.1 and 3.9.8".
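
A quick way to see whether a given interpreter falls in the affected range is to compare its version against the releases named in that comment. This is a sketch based only on those three version numbers; the function name literal_bug_fixed is hypothetical, and treating releases older than 3.9 as unfixed is an assumption.

```python
import sys


def literal_bug_fixed(version=None):
    """Return True if the CPython version is at or past the fix releases quoted above."""
    version = tuple(sys.version_info[:3]) if version is None else tuple(version)
    if version[:2] == (3, 9):
        return version >= (3, 9, 8)
    if version[:2] == (3, 10):
        return version >= (3, 10, 1)
    # Assumption: anything older than 3.9 never received the fix.
    return version >= (3, 11, 0)


if __name__ == "__main__":
    print(sys.version_info[:3], literal_bug_fixed())
```

For example, literal_bug_fixed((3, 10, 0)) is False while literal_bug_fixed((3, 10, 1)) is True, which would explain why the failure did not reproduce on a newer 3.10 patch release.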

So it seems the issue is related to a bug in old CPython releases, 3.10.0 being one of them. I have changed the fix so that it hopefully covers all cases. However, I have no way of testing this. @Galaxy-Husky could you please try it out? Just do the following and run your code again:

pip uninstall -y jsonargparse
pip install "jsonargparse @"
Galaxy-Husky commented 1 year ago

@mauvilsa Thanks a lot!! The problem has been solved.