Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

ValueError: dictionary update sequence element #0 has length 1; 2 is required #19730

Closed: pau-altur closed this issue 4 months ago

pau-altur commented 6 months ago

Bug description

I am trying to train a Lightning model that inherits from pl.LightningModule and implements a simple feed-forward network. When I run it, trainer.fit() raises the error trace below. I found a very similar issue where downgrading to torchmetrics<=0.5.0 fixed the problem, but that is not possible in my case, since v2.2.0 of pytorch-lightning is not compatible with such an old version of torchmetrics. I tried downgrading to 0.7.x, the oldest compatible version, but that led to a different error, also raised from trainer.fit().

Thanks for your attention and I would appreciate any help with this.
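
Looking at the trace, the failure happens in save_hparams_to_yaml, i.e. when the CSV logger tries to YAML-serialize the hyperparameters captured by self.save_hyperparameters(), which in my model include the loss modules and torchmetrics objects. A minimal sketch of what I suspect is failing (the R2Score value here is an assumption for illustration, not taken from my actual run):

import yaml
from torchmetrics import R2Score

# save_hparams_to_yaml (pytorch_lightning/core/saving.py in the trace below)
# effectively calls yaml.dump on each hyperparameter value; for a list of
# (name, Metric) pairs PyYAML falls back to its generic object representer,
# which is expected to raise the same ValueError as in the trace.
yaml.dump([("r2", R2Score())])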

What version are you seeing the problem on?

v2.2

How to reproduce the bug

Below is the model class definition

import pytorch_lightning as pl
import torch
import numpy as np
from torch.nn import MSELoss, L1Loss
from torchmetrics import R2Score

torch.random.manual_seed(123)

class LightningModelSimple(pl.LightningModule):
    def __init__(
        self,
        latent_model,
        readout_model=None,
        losses={},
        metrics=[],
        gpu=True,
        learning_rate=0.001,
        weight_decay=0.0,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.latent_model = latent_model
        if readout_model is None:
            self.readout_model = torch.nn.Identity()
        else:
            self.readout_model = readout_model

        # losses
        if "target" in losses:
            self.loss_target = losses["target"]
        else:
            self.loss_target = None
        if "latent_target" in losses:
            self.loss_latent_target = losses["latent_target"]
            self.weight_loss_latent_target = losses["weight_loss_latent_target"]
        else:
            self.loss_latent_target = None

        self.gpu = gpu
        self.metrics = metrics
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, x):
        x_latent = self.latent_model(x)
        y = self.readout_model(x_latent)
        return y

    def step(self, partition, batch, batch_idx):
        spectra, target_glucose = batch

        # get latent predictions
        self.pred_latent = self.latent_model(spectra.float())

        # get glucose predictions
        self.pred_glucose = self.readout_model(self.pred_latent)

        # compute losses
        loss = 0
        if self.loss_target is not None:
            loss += self.loss_target(self.pred_glucose, target_glucose)
            self.log(partition + "_loss_target", loss, on_epoch=True)

        if self.loss_latent_target is not None:
            loss_latent_target = (
                self.weight_loss_latent_target
                * self.loss_latent_target(self.pred_latent, target_glucose.unsqueeze(1))
            )
            self.log(
                partition + "_loss_latent_target", loss_latent_target, on_epoch=True
            )
            loss += loss_latent_target

        self.log(partition + "_loss_total", loss, on_epoch=True)
        for metric_name, metric in self.metrics:
            self.log(
                partition + "_" + metric_name,
                metric(self.pred_glucose, target_glucose),
                on_epoch=True,
            )

        return loss

    def training_step(self, batch, batch_idx):
        return self.step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self.step("test", batch, batch_idx)

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
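
If the YAML serialization of the captured hyperparameters is indeed what breaks, a plausible workaround (a sketch I have not verified against this exact setup) is to exclude the non-serializable module arguments from save_hyperparameters, and to avoid mutable default arguments while at it:

class LightningModelSimple(pl.LightningModule):
    def __init__(self, latent_model, readout_model=None, losses=None,
                 metrics=None, gpu=True, learning_rate=0.001, weight_decay=0.0):
        super().__init__()
        # keep only plain scalars in hparams; nn.Modules and Metrics cannot
        # be dumped into the CSV logger's hparams.yaml
        self.save_hyperparameters(
            ignore=["latent_model", "readout_model", "losses", "metrics"]
        )
        self.metrics = metrics if metrics is not None else []
        # ... rest of __init__ unchanged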

This should go in a different file called helpers.py

import collections.abc
import glob
from argparse import ArgumentParser

import mlflow
import pytorch_lightning as pl
import yaml
from pytorch_lightning.utilities import CombinedLoader

# initialize_datamodule and initialize_modules are project-specific helpers
# and are not included here.

def log_parameter(params, parser, param_name=""):
    if isinstance(params, dict):
        for key in params.keys():
            if key == "class_path":
                parser = log_parameter(params[key], parser, param_name)
            else:
                parser = log_parameter(params[key], parser, key)
    else:
        parser.add_argument("--" + param_name, type=type(params), default=params)
    return parser

def update(config_data, params):
    for k, v in params.items():
        if isinstance(v, collections.abc.Mapping):
            config_data[k] = update(config_data.get(k, {}), v)
        else:
            config_data[k] = v
    return config_data
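
# Example of the recursive merge performed by update():
#   update({"a": {"b": 1}}, {"a": {"c": 2}}) == {"a": {"b": 1, "c": 2}}
# Nested mappings are merged key by key; non-mapping values are overwritten.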

def train_model(config_file, **kwargs):
    loader = yaml.SafeLoader
    with open(config_file, "r") as stream:
        config_data = yaml.load(stream, Loader=loader)

    if "params" in kwargs:
        config_data = update(config_data, kwargs["params"])
    if "latent_model" in kwargs:
        config_data["lightning_model"]["init_args"]["latent_model"] = kwargs[
            "latent_model"
        ]

    # experiment_name = config_data["experiment_name"]
    n_epochs = config_data["n_epochs"]

    pl.seed_everything(1234)

    # add arguments to parser
    parser = ArgumentParser(conflict_handler="resolve")
    parser.add_argument(
        "--auto-select-gpus", default=True, help="run automatically on GPU if available"
    )
    parser.add_argument("--max-epochs", default=n_epochs, type=int)
    parser.add_argument("gpus", type=int, default=1)
    parser = log_parameter(config_data, parser)

    # parse arguments to trainer
    args = parser.parse_args()

    if args.gpus == 1:
        device = "cuda"
    elif args.gpus == 0:
        device = "cpu"

    # create mlflow experiment if it doesn't yet exist
    try:
        current_experiment = dict(mlflow.get_experiment_by_name(args.experiment_name))
        experiment_id = current_experiment["experiment_id"]
    except Exception:  # get_experiment_by_name returns None if the experiment does not exist
        print("creating new experiment")
        experiment_id = mlflow.create_experiment(args.experiment_name)

    # # start experiment
    with mlflow.start_run(experiment_id=experiment_id) as run:
        with open("log.txt", "a") as log_file:
            log_file.write("'" + str(run.info.run_id) + "'" + ", ")

        path_mlflow_results = (
            "mlruns/" + str(experiment_id) + "/" + str(run.info.run_id)
        )
        path_checkpoints = path_mlflow_results + "/checkpoints"

        # copy yaml file to mlfow results
        # TODO: this is a hack for now, this should automatically be logged
        # with open(path_mlflow_results + "/" + config_file, "w") as f:
        with open(path_mlflow_results + "/config.yaml", "w") as f:
            yaml.dump(config_data, f)

        # initialize dataloader
        config_data = initialize_datamodule(config_data)
        datamodule = config_data["datamodule"]
        # extract key for model selection
        loss_key = config_data["metric_model_selection"]
        if (
            config_data["datamodule"].split_label_val == "Barcode"
            and "val_" in loss_key[0]
        ):
            raise ValueError(
                "split_label_val=Barcode with metric_model_selection=",
                loss_key,
                " introduces data leakage",
            )

        # initialize lightning model
        if (
            config_data["lightning_model"]["class_path"]
            == "models.lightning_model.LightningModel"
        ):
            use_val_test_data_in_train = True
        elif (
            config_data["lightning_model"]["class_path"]
            == "models.lightning_model.LightningModelSimple"
        ):
            use_val_test_data_in_train = False
        config_data = initialize_modules(config_data)
        lightning_model = config_data["lightning_model"]
        print(type(lightning_model))
        print(type(datamodule))
        # monitor different metrics depending on loss variable
        checkpoints = []
        monitored_metrics = config_data["monitored_metrics"]
        for i, (me, mo) in enumerate(monitored_metrics):
            ckpt = pl.callbacks.ModelCheckpoint(
                monitor=me,
                mode=mo,
                dirpath=path_checkpoints,
                filename="{epoch:02d}-{" + me + ":.4f}",
                save_top_k=1,
            )
            checkpoints.append(ckpt)
        # checkpoints.append(
        #     pl.callbacks.ModelCheckpoint(
        #         dirpath=path_checkpoints,
        #         filename="every_n_{epoch:02d}",
        #         every_n_epochs=10,
        #         save_top_k=-1,  # <--- this is important!
        #     )
        # )

        # log all parameter
        mlflow.pytorch.autolog()
        for arg in vars(args):
            mlflow.log_param(arg, getattr(args, arg))

        # train model
        trainer = pl.Trainer(max_epochs=n_epochs, logger=True, callbacks=checkpoints)

        # TODO: this is very hacky and should be revisited
        # we create a combined dataloader which is the same for train/validation/test
        # batching is applied to the train dataloader, so there will be multiple
        # batches with the batch size defined in config.yaml; the validation and
        # test dataloaders only have one batch, which spans the entire
        # validation/test set. Inside the LightningModule we read out the
        # validation and test batch at step 0 and save it as a class attribute
        # so that all validation and test data can be used in all training steps
        if use_val_test_data_in_train:
            datamodule.setup(stage="")
            iterables_train = {
                "train": datamodule.train_dataloader(),
                "val": datamodule.val_dataloader(),
                "test": datamodule.test_dataloader(),
            }
            iterables_val = {
                "train": datamodule.train_dataloader(),
                "val": datamodule.val_dataloader(),
                "test": datamodule.test_dataloader(),
            }
            iterables_test = {
                "train": datamodule.train_dataloader(),
                "val": datamodule.val_dataloader(),
                "test": datamodule.test_dataloader(),
            }
            combined_loader_train = CombinedLoader(iterables_train, mode="max_size")
            combined_loader_val = CombinedLoader(iterables_val, mode="max_size")
            combined_loader_test = CombinedLoader(iterables_test, mode="max_size")
            trainer.fit(lightning_model, combined_loader_train, combined_loader_val)
        else:
            trainer.fit(lightning_model, datamodule=datamodule)

        # evaluate tests for all monitored metrics
        ckpts = glob.glob(path_checkpoints + "/*")
        for ckpt in ckpts:
            if loss_key[0] in ckpt:
                if use_val_test_data_in_train:
                    result = trainer.test(
                        dataloaders=combined_loader_test, ckpt_path=ckpt
                    )
                else:
                    result = trainer.test(datamodule=datamodule, ckpt_path=ckpt)
                print(result)

Finally the main file

import torch
import utils.helpers as helpers

torch.random.manual_seed(123)

if __name__ == "__main__":
    # profil data
    # train_model("config_profil_latent.yaml")
    # train_model("config_profil_readout.yaml")
    # train_model("config_profil.yaml")
    # train_model("config_profil_simple.yaml")
    for weight_decay in [1.0]:
        for val_subject in range(0, 14):
            params = {
                "datamodule": {
                    "init_args": {
                        "val_index": [val_subject],
                        "test_index": [],
                    }
                },
                "lightning_model": {
                    "init_args": {
                        "weight_decay": weight_decay,
                    }
                },
            }
            helpers.train_model("config_profil_simple.yaml", params=params)
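
For reference, step() iterates self.metrics as (name, metric) pairs, so the config instantiates the model along these lines (the exact losses and metrics are assumptions for illustration; my real values come from config_profil_simple.yaml):

from torch.nn import MSELoss
from torchmetrics import R2Score

# hypothetical instantiation matching what step() expects
model = LightningModelSimple(
    latent_model=my_latent_model,   # placeholder for the real module
    losses={"target": MSELoss()},
    metrics=[("r2", R2Score())],    # (name, Metric) pairs
)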

Error messages and logs

Traceback (most recent call last):
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 969, in _run
    _log_hyperparams(self)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/loggers/utilities.py", line 95, in _log_hyperparams
    logger.save()
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 42, in wrapped_fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py", line 157, in save
    self.experiment.save()
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/loggers/csv_logs.py", line 67, in save
    save_hparams_to_yaml(hparams_file, self.hparams)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 354, in save_hparams_to_yaml
    yaml.dump(v)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 253, in dump
    return dump_all([data], stream, Dumper=Dumper, **kwds)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 241, in dump_all
    dumper.represent(data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 27, in represent
    node = self.represent_data(data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
    return self.represent_sequence('tag:yaml.org,2002:seq', data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
    return self.represent_sequence('tag:yaml.org,2002:seq', data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 356, in represent_object
    return self.represent_mapping(tag+function_name, value)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 330, in represent_object
    dictitems = dict(dictitems)
ValueError: dictionary update sequence element #0 has length 1; 2 is required

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/pap_spiden_com/spiden_ds/experiments/artemis/main.py", line 28, in <module>
    helpers.train_model("config_profil_simple.yaml", params=params)
  File "/home/pap_spiden_com/spiden_ds/experiments/artemis/utils/helpers.py", line 191, in train_model
    trainer.fit(lightning_model, datamodule=datamodule)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 573, in safe_patch_function
    patch_function(call_original, *args, **kwargs)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 252, in patch_with_managed_run
    result = patch_function(original, *args, **kwargs)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/pytorch/_lightning_autolog.py", line 386, in patched_fit
    result = original(self, *args, **kwargs)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 554, in call_original
    return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 489, in call_original_fn_with_event_logging
    original_fn_result = original_fn(*og_args, **og_kwargs)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 551, in _original_fn
    original_result = original(*_og_args, **_og_kwargs)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 67, in _call_and_handle_interrupt
    logger.finalize("failed")
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 42, in wrapped_fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py", line 166, in finalize
    self.save()
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 42, in wrapped_fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py", line 157, in save
    self.experiment.save()
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/loggers/csv_logs.py", line 67, in save
    save_hparams_to_yaml(hparams_file, self.hparams)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 354, in save_hparams_to_yaml
    yaml.dump(v)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 253, in dump
    return dump_all([data], stream, Dumper=Dumper, **kwds)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 241, in dump_all
    dumper.represent(data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 27, in represent
    node = self.represent_data(data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
    return self.represent_sequence('tag:yaml.org,2002:seq', data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
    return self.represent_sequence('tag:yaml.org,2002:seq', data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 356, in represent_object
    return self.represent_mapping(tag+function_name, value)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 330, in represent_object
    dictitems = dict(dictitems)
ValueError: dictionary update sequence element #0 has length 1; 2 is required

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

No response

Hobbes-Le-Chat commented 4 months ago

Has this issue been solved? If so, what is the solution? I have the same issue with ROCm.

adosar commented 4 months ago

The same problem here, also with torchmetrics==1.4.0.post0 and lightning==2.2.5.

radomirgr commented 4 months ago

Also having the same problem.