I am trying to train a Lightning model that inherits from pl.LightningModule and implements a simple feed-forward network. The issue is that when I run it, it spits out the below error trace coming from trainer.fit(). I found this very similar issue, where downgrading to torchmetrics<=0.5.0 fixed the issue, but that is not possible in my case as v2.2.0 of pytorch-lightning is not compatible with such an old version of torchmetrics. I tried downgrading to 0.7., the oldest compatible version, but it led to a different error also in the trainer.fit method.
Thanks for your attention and I would appreciate any help with this.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
Below is the model class definition
import pytorch_lightning as pl
import torch
import numpy as np
from torch.nn import MSELoss, L1Loss
from torchmetrics import R2Score
torch.random.manual_seed(123)
class LightningModelSimple(pl.LightningModule):
def __init__(
self,
latent_model,
readout_model=None,
losses={},
metrics=[],
gpu=True,
learning_rate=0.001,
weight_decay=0.0,
):
super().__init__()
self.save_hyperparameters()
self.latent_model = latent_model
if readout_model is None:
self.readout_model = torch.nn.Identity()
else:
self.readout_model = readout_model
# losses
if "target" in losses:
self.loss_target = losses["target"]
else:
self.loss_target = None
if "latent_target" in losses:
self.loss_latent_target = losses["latent_target"]
self.weight_loss_latent_target = losses["weight_loss_latent_target"]
else:
self.loss_latent_target = None
self.gpu = gpu
self.metrics = metrics
self.learning_rate = learning_rate
self.weight_decay = weight_decay
def forward(self, x):
x_latent = self.latent_model(x)
y = self.readout_model(x_latent)
return y
def step(self, partition, batch, batch_idx):
spectra, target_glucose = batch
# get latent predictions
self.pred_latent = self.latent_model(spectra.float())
# get glucose predictions
self.pred_glucose = self.readout_model(self.pred_latent)
# compute losses
loss = 0
if self.loss_target is not None:
loss += self.loss_target(self.pred_glucose, target_glucose)
self.log(partition + "_loss_target", loss, on_epoch=True)
if self.loss_latent_target is not None:
loss_latent_target = (
self.weight_loss_latent_target
* self.loss_latent_target(self.pred_latent, target_glucose.unsqueeze(1))
)
self.log(
partition + "_loss_latent_target", loss_latent_target, on_epoch=True
)
loss += loss_latent_target
self.log(partition + "_loss_total", loss, on_epoch=True)
for metric_name, metric in self.metrics:
self.log(
partition + "_" + metric_name,
metric(self.pred_glucose, target_glucose),
on_epoch=True,
)
return loss
def training_step(self, batch, batch_idx):
return self.step("train", batch, batch_idx)
def validation_step(self, batch, batch_idx):
return self.step("val", batch, batch_idx)
def test_step(self, batch, batch_idx):
return self.step("test", batch, batch_idx)
def configure_optimizers(self):
return torch.optim.Adam(
self.parameters(),
lr=self.hparams.learning_rate,
weight_decay=self.hparams.weight_decay,
)
This should go in a different file called helpers.py
def log_parameter(params, parser, param_name=""):
if isinstance(params, dict):
for key in params.keys():
if key == "class_path":
parser = log_parameter(params[key], parser, param_name)
else:
parser = log_parameter(params[key], parser, key)
else:
parser.add_argument("--" + param_name, type=type(params), default=params)
return parser
def update(config_data, params):
for k, v in params.items():
if isinstance(v, collections.abc.Mapping):
config_data[k] = update(config_data.get(k, {}), v)
else:
config_data[k] = v
return config_data
def train_model(config_file, **kwargs):
loader = yaml.SafeLoader
with open(config_file, "r") as stream:
config_data = yaml.load(stream, Loader=loader)
if "params" in kwargs:
config_data = update(config_data, kwargs["params"])
if "latent_model" in kwargs:
config_data["lightning_model"]["init_args"]["latent_model"] = kwargs[
"latent_model"
]
# experiment_name = config_data["experiment_name"]
n_epochs = config_data["n_epochs"]
pl.seed_everything(1234)
# add arguments to parser
parser = ArgumentParser(conflict_handler="resolve")
parser.add_argument(
"--auto-select-gpus", default=True, help="run automatically on GPU if available"
)
parser.add_argument("--max-epochs", default=n_epochs, type=int)
parser.add_argument("gpus", type=int, default=1)
parser = log_parameter(config_data, parser)
# parse arguments to trainer
args = parser.parse_args()
if args.gpus == 1:
device = "cuda"
elif args.gpus == 0:
device = "cpu"
# create mlflow experiment if it doesn't yet exist
try:
current_experiment = dict(mlflow.get_experiment_by_name(args.experiment_name))
experiment_id = current_experiment["experiment_id"]
except:
print("creating new experiment")
experiment_id = mlflow.create_experiment(args.experiment_name)
# # start experiment
with mlflow.start_run(experiment_id=experiment_id) as run:
with open("log.txt", "a") as log_file:
log_file.write("'" + str(run.info.run_id) + "'" + ", ")
path_mlflow_results = (
"mlruns/" + str(experiment_id) + "/" + str(run.info.run_id)
)
path_checkpoints = path_mlflow_results + "/checkpoints"
# copy yaml file to mlfow results
# TODO: this is a hack for now, this should automatically be logged
# with open(path_mlflow_results + "/" + config_file, "w") as f:
with open(path_mlflow_results + "/config.yaml", "w") as f:
yaml.dump(config_data, f)
# initialize dataloader
config_data = initialize_datamodule(config_data)
datamodule = config_data["datamodule"]
# extract key for model selection
loss_key = config_data["metric_model_selection"]
if (
config_data["datamodule"].split_label_val == "Barcode"
and "val_" in loss_key[0]
):
raise ValueError(
"split_label_val=Barcode with metric_model_selection=",
loss_key,
" introduces data leakage",
)
# initialize lightning model
if (
config_data["lightning_model"]["class_path"]
== "models.lightning_model.LightningModel"
):
use_val_test_data_in_train = True
elif (
config_data["lightning_model"]["class_path"]
== "models.lightning_model.LightningModelSimple"
):
use_val_test_data_in_train = False
config_data = initialize_modules(config_data)
lightning_model = config_data["lightning_model"]
print(type(lightning_model))
print(type(datamodule))
# monitor different metrics depending on loss variable
checkpoints = []
monitored_metrics = config_data["monitored_metrics"]
for i, (me, mo) in enumerate(monitored_metrics):
ckpt = pl.callbacks.ModelCheckpoint(
monitor=me,
mode=mo,
dirpath=path_checkpoints,
filename="{epoch:02d}-{" + me + ":.4f}",
save_top_k=1,
)
checkpoints.append(ckpt)
# checkpoints.append(
# pl.callbacks.ModelCheckpoint(
# dirpath=path_checkpoints,
# filename="every_n_{epoch:02d}",
# every_n_epochs=10,
# save_top_k=-1, # <--- this is important!
# )
# )
# log all parameter
mlflow.pytorch.autolog()
for arg in vars(args):
mlflow.log_param(arg, getattr(args, arg))
# train model
trainer = pl.Trainer(max_epochs=n_epochs, logger=True, callbacks=checkpoints)
# TODO: this is very hackey and should be revisited
# we create a combined dataloader which is the same for train/validation/test
# batching is applied to the train dataloader, thus there will be multiple batches with the batch size defined in config.yaml
# the validation and test datloaders only have one batch which has the size of the entire validation/test set
# insight the lightning module we read out the validation and test batch at step 0 and save it as a class
# attribute such that all validation and test data can be used in all training steps
if use_val_test_data_in_train:
datamodule.setup(stage="")
iterables_train = {
"train": datamodule.train_dataloader(),
"val": datamodule.val_dataloader(),
"test": datamodule.test_dataloader(),
}
iterables_val = {
"train": datamodule.train_dataloader(),
"val": datamodule.val_dataloader(),
"test": datamodule.test_dataloader(),
}
iterables_test = {
"train": datamodule.train_dataloader(),
"val": datamodule.val_dataloader(),
"test": datamodule.test_dataloader(),
}
combined_loader_train = CombinedLoader(iterables_train, mode="max_size")
combined_loader_val = CombinedLoader(iterables_val, mode="max_size")
combined_loader_test = CombinedLoader(iterables_test, mode="max_size")
trainer.fit(lightning_model, combined_loader_train, combined_loader_val)
else:
trainer.fit(lightning_model, datamodule=datamodule)
# evaluate tests for all monitored metrics
ckpts = glob.glob(path_checkpoints + "/*")
for ckpt in ckpts:
if loss_key[0] in ckpt:
if use_val_test_data_in_train:
result = trainer.test(
dataloaders=combined_loader_test, ckpt_path=ckpt
)
else:
result = trainer.test(datamodule=datamodule, ckpt_path=ckpt)
print(result)
Finally the main file
import torch
import utils.helpers as helpers
torch.random.manual_seed(123)
if __name__ == "__main__":
# profil data
# train_model("config_profil_latent.yaml")
# train_model("config_profil_readout.yaml")
# train_model("config_profil.yaml")
# train_model("config_profil_simple.yaml")
for weight_decay in [1.0]:
for val_subject in range(0, 14):
params = {
"datamodule": {
"init_args": {
"val_index": [val_subject],
"test_index": [],
}
},
"lightning_model": {
"init_args": {
"weight_decay": weight_decay,
}
},
}
helpers.train_model("config_profil_simple.yaml", params=params)
Error messages and logs
Traceback (most recent call last):
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 969, in _run
_log_hyperparams(self)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/loggers/utilities.py", line 95, in _log_hyperparams
logger.save()
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 42, in wrapped_fn
return fn(*args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py", line 157, in save
self.experiment.save()
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/loggers/csv_logs.py", line 67, in save
save_hparams_to_yaml(hparams_file, self.hparams)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 354, in save_hparams_to_yaml
yaml.dump(v)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 253, in dump
return dump_all([data], stream, Dumper=Dumper, **kwds)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 241, in dump_all
dumper.represent(data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 27, in represent
node = self.represent_data(data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
return self.represent_sequence('tag:yaml.org,2002:seq', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
node_item = self.represent_data(item)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
return self.represent_sequence('tag:yaml.org,2002:seq', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
node_item = self.represent_data(item)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
node = self.yaml_multi_representers[data_type](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 356, in represent_object
return self.represent_mapping(tag+function_name, value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
node_value = self.represent_data(item_value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 207, in represent_dict
return self.represent_mapping('tag:yaml.org,2002:map', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
node_value = self.represent_data(item_value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
node = self.yaml_multi_representers[data_type](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 330, in represent_object
dictitems = dict(dictitems)
ValueError: dictionary update sequence element #0 has length 1; 2 is required
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/pap_spiden_com/spiden_ds/experiments/artemis/main.py", line 28, in <module>
helpers.train_model("config_profil_simple.yaml", params=params)
File "/home/pap_spiden_com/spiden_ds/experiments/artemis/utils/helpers.py", line 191, in train_model
trainer.fit(lightning_model, datamodule=datamodule)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 573, in safe_patch_function
patch_function(call_original, *args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 252, in patch_with_managed_run
result = patch_function(original, *args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/pytorch/_lightning_autolog.py", line 386, in patched_fit
result = original(self, *args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 554, in call_original
return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 489, in call_original_fn_with_event_logging
original_fn_result = original_fn(*og_args, **og_kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 551, in _original_fn
original_result = original(*_og_args, **_og_kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
call._call_and_handle_interrupt(
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 67, in _call_and_handle_interrupt
logger.finalize("failed")
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 42, in wrapped_fn
return fn(*args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py", line 166, in finalize
self.save()
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 42, in wrapped_fn
return fn(*args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py", line 157, in save
self.experiment.save()
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/loggers/csv_logs.py", line 67, in save
save_hparams_to_yaml(hparams_file, self.hparams)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 354, in save_hparams_to_yaml
yaml.dump(v)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 253, in dump
return dump_all([data], stream, Dumper=Dumper, **kwds)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 241, in dump_all
dumper.represent(data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 27, in represent
node = self.represent_data(data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
return self.represent_sequence('tag:yaml.org,2002:seq', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
node_item = self.represent_data(item)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
return self.represent_sequence('tag:yaml.org,2002:seq', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
node_item = self.represent_data(item)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
node = self.yaml_multi_representers[data_type](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 356, in represent_object
return self.represent_mapping(tag+function_name, value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
node_value = self.represent_data(item_value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 207, in represent_dict
return self.represent_mapping('tag:yaml.org,2002:map', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
node_value = self.represent_data(item_value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
node = self.yaml_multi_representers[data_type](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 330, in represent_object
dictitems = dict(dictitems)
ValueError: dictionary update sequence element #0 has length 1; 2 is required
Environment
Current environment
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```
Bug description
I am trying to train a Lightning model that inherits from pl.LightningModule and implements a simple feed-forward network. The issue is that when I run it, it spits out the below error trace coming from trainer.fit(). I found this very similar issue, where downgrading to
torchmetrics<=0.5.0
fixed the issue, but that is not possible in my case as v2.2.0 of pytorch-lightning is not compatible with such an old version of torchmetrics. I tried downgrading to 0.7., the oldest compatible version, but it led to a different error also in the trainer.fit method.Thanks for your attention and I would appreciate any help with this.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
Error messages and logs
Environment
Current environment
``` #- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): #- PyTorch Lightning Version (e.g., 1.5.0): #- Lightning App Version (e.g., 0.5.2): #- PyTorch Version (e.g., 2.0): #- Python version (e.g., 3.9): #- OS (e.g., Linux): #- CUDA/cuDNN version: #- GPU models and configuration: #- How you installed Lightning(`conda`, `pip`, source): #- Running environment of LightningApp (e.g. local, cloud): ```More info
No response