Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Limited types of data can be logged #8088

Closed borisdayma closed 2 years ago

borisdayma commented 3 years ago

🐛 Bug

Function __check_allowed limits the types of values that can be logged, even when the data would be supported by loggers such as wandb.

To Reproduce

See the following notebook.
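(For reference, a minimal sketch of the callback that triggers the error, reconstructed from the error stack below; the class name and sample count are illustrative, not the exact notebook code.)

```python
import wandb
from pytorch_lightning import Callback


class LogPredictionsCallback(Callback):  # illustrative name
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if batch_idx == 0:
            n = 20  # number of samples to log, illustrative
            x, y = batch
            # `LightningModule.log` only accepts numbers/tensors/Metrics/dicts,
            # so passing a list of wandb.Image raises the ValueError shown below
            pl_module.log('examples', [
                wandb.Image(x_i, caption=f'Ground Truth: {y_i}\nPrediction: {y_pred}')
                for x_i, y_i, y_pred in zip(x[:n], y[:n], outputs[:n])
            ])
```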

Error stack:

ValueError                                Traceback (most recent call last)
<ipython-input-12-6d5bedbc661e> in <module>()
----> 1 trainer.fit(model, datamodule=mnist)

18 frames
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader)
    509         self.checkpoint_connector.resume_start()
    510 
--> 511         self._run(model)
    512 
    513         assert self.state.stopped

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    870 
    871         # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 872         self._dispatch()
    873 
    874         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
    912             self.accelerator.start_predicting(self)
    913         else:
--> 914             self.accelerator.start_training(self)
    915 
    916     def run_stage(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     94 
     95     def start_training(self, trainer: 'pl.Trainer') -> None:
---> 96         self.training_type_plugin.start_training(trainer)
     97 
     98     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    163     def start_training(self, trainer: 'pl.Trainer') -> None:
    164         # double dispatch to initiate the training loop
--> 165         self._results = trainer.run_stage()
    166 
    167     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
    922         if self.predicting:
    923             return self._run_predict()
--> 924         return self._run_train()
    925 
    926     def _pre_training_routine(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
    956             self.progress_bar_callback.disable()
    957 
--> 958         self._run_sanity_check(self.lightning_module)
    959 
    960         self.checkpoint_connector.has_trained = False

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self, ref_model)
   1065             # run eval step
   1066             with torch.no_grad():
-> 1067                 self.evaluation_loop.run()
   1068 
   1069             self.on_sanity_check_end()

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
     89             try:
     90                 self.on_advance_start(*args, **kwargs)
---> 91                 self.advance(*args, **kwargs)
     92                 self.on_advance_end()
     93                 self.iteration_count += 1

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py in advance(self, *args, **kwargs)
    118             self.current_dataloader_idx,
    119             dl_max_batches,
--> 120             self.num_dataloaders,
    121         )
    122 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
     89             try:
     90                 self.on_advance_start(*args, **kwargs)
---> 91                 self.advance(*args, **kwargs)
     92                 self.on_advance_end()
     93                 self.iteration_count += 1

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/evaluation_epoch_loop.py in advance(self, dataloader_iter, dataloader_idx, dl_max_batches, num_dataloaders)
    110 
    111         # hook + store predictions
--> 112         self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)
    113 
    114         # log batch metrics

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/evaluation_epoch_loop.py in on_evaluation_batch_end(self, output, batch, batch_idx, dataloader_idx)
    194         """
    195         hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
--> 196         self.trainer.call_hook(hook_name, output, batch, batch_idx, dataloader_idx)
    197 
    198         self.trainer.logger_connector.on_batch_end()

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in call_hook(self, hook_name, *args, **kwargs)
   1165             if hasattr(self, hook_name):
   1166                 trainer_hook = getattr(self, hook_name)
-> 1167                 trainer_hook(*args, **kwargs)
   1168 
   1169             # next call hook in lightningModule

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/callback_hook.py in on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx)
    195         """Called when the validation batch ends."""
    196         for callback in self.callbacks:
--> 197             callback.on_validation_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
    198 
    199     def on_test_batch_start(self, batch, batch_idx, dataloader_idx):

<ipython-input-8-bd547d9f98a3> in on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
     16             # we use `LightningModule.log` method for logging
     17             pl_module.log('examples', [wandb.Image(x_i, caption=f'Ground Truth: {y_i}\nPrediction: {y_pred}')
---> 18                                        for x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs[:n]))])

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/lightning.py in log(self, name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, tbptt_reduce_fx, tbptt_pad_token, enable_graph, sync_dist, sync_dist_op, sync_dist_group, add_dataloader_idx, batch_size)
    336         apply_to_collection(value, dict, self.__check_not_nested, name)
    337         apply_to_collection(
--> 338             value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict)
    339         )
    340 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/apply_func.py in apply_to_collection(data, dtype, function, wrong_dtype, include_none, *args, **kwargs)
     94     # Breaking condition
     95     if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
---> 96         return function(data, *args, **kwargs)
     97 
     98     elem_type = type(data)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/lightning.py in __check_allowed(v, name, value)
    447     @staticmethod
    448     def __check_allowed(v: Any, name: str, value: Any) -> None:
--> 449         raise ValueError(f'`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged')
    450 
    451     def __to_tensor(self, value: numbers.Number) -> torch.Tensor:

ValueError: `self.log(examples, [<wandb.sdk.data_types.Image object at 0x7f3a2ea28310>, <wandb.sdk.data_types.Image object at 0x7f3a2e1d6350>, <wandb.sdk.data_types.Image object at 0x7f3a2e1d6910>, <wandb.sdk.data_types.Image object at 0x7f3a2e1d6b50>, <wandb.sdk.data_types.Image object at 0x7f3a2e16f650>, <wandb.sdk.data_types.Image object at 0x7f3a2e189250>, <wandb.sdk.data_types.Image object at 0x7f3a2e1649d0>, <wandb.sdk.data_types.Image object at 0x7f3a2e17b790>, <wandb.sdk.data_types.Image object at 0x7f3a2e17f5d0>, <wandb.sdk.data_types.Image object at 0x7f3a2e16f610>, <wandb.sdk.data_types.Image object at 0x7f3a2e17f690>, <wandb.sdk.data_types.Image object at 0x7f3a2e17f750>, <wandb.sdk.data_types.Image object at 0x7f3a2e17f410>, <wandb.sdk.data_types.Image object at 0x7f3a2e17f8d0>, <wandb.sdk.data_types.Image object at 0x7f3a2e17f950>, <wandb.sdk.data_types.Image object at 0x7f3a2e17fa10>, <wandb.sdk.data_types.Image object at 0x7f3a2e17fad0>, <wandb.sdk.data_types.Image object at 0x7f3a2e17fb90>, <wandb.sdk.data_types.Image object at 0x7f3a2e17fc50>, <wandb.sdk.data_types.Image object at 0x7f3a2e17fd10>])` was called, but `list` values cannot be logged

Expected behavior

This data is supported by the wandb logger, which can log it (and display it) with no problem, and is used this way by many users.

The validation of data types should be tied to the logger, as some loggers are more flexible than others.

Environment

Additional context

More and more data types are supported by wandb (3D graphs, audio, custom data types, etc.), so it would be good to be able to take advantage of them.

A workaround would be to log without going through the LightningModule (calling wandb.log directly), but we lose some interesting Lightning features (such as auto-tracking of global_step during logging) and also lose backward compatibility with existing code.
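(A minimal sketch of what that workaround could look like in a callback; the class name and sample count are illustrative, and the manual step handling is exactly what self.log would otherwise do for us.)

```python
import wandb
from pytorch_lightning import Callback


class DirectWandbCallback(Callback):  # illustrative name
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if batch_idx == 0:
            x, _ = batch
            # bypass LightningModule.log and call wandb directly,
            # passing the global step by hand since nothing tracks it for us
            wandb.log({'examples': [wandb.Image(x_i) for x_i in x[:8]]}, step=trainer.global_step)
```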

carmocca commented 3 years ago

Hi!

We recently redesigned the internal logging logic. Was this something supported before? Can you provide a minimal repro to play around with and test?

self.log is designed to support numbers, tensors, or dictionaries of the former. To support this, we would either need to provide another function (e.g. self.log_data) or add a flag to self.log so that most of the logging logic is skipped (e.g. reductions and aggregations).
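(As a rough illustration of the currently accepted value types; this is a sketch with placeholder values, not the actual validation code.)

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # illustrative module
    def validation_step(self, batch, batch_idx):
        loss = torch.tensor(0.123)  # placeholder value
        # numbers, tensors, Metric instances, and dicts of those pass __check_allowed
        self.log('val_loss', loss)
        self.log('val_acc', 0.93)
        self.log_dict({'val_loss_mirror': loss, 'val_acc_mirror': 0.93})
        # an arbitrary object (e.g. a list of wandb.Image) would raise ValueError here
```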

borisdayma commented 3 years ago

Hi, yes, there is a colab in the issue under the "To Reproduce" section.

tchaton commented 3 years ago

Hey @carmocca , @borisdayma,

I don't believe we ever supported logging images directly with the self.log function. I think having self.log_image or a derivative would be the cleanest way forward.

Best, T.C

borisdayma commented 3 years ago

It works with the latest released version on the exact same notebook: run

tchaton commented 3 years ago

Hey @borisdayma,

Interesting, I don't understand how this could have worked :) Need to investigate.

Best, T.C

ananthsub commented 3 years ago

@borisdayma as a workaround, in the callback could you directly log with trainer.logger.experiment.<wandb_api> ?
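(A minimal sketch of that suggestion, assuming a WandbLogger is attached so that trainer.logger.experiment is the underlying wandb Run; the class name and sample count are illustrative.)

```python
import wandb
from pytorch_lightning import Callback


class WandbImagesCallback(Callback):  # illustrative name
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if batch_idx == 0:
            x, y = batch
            # with a WandbLogger attached, `experiment` is the wandb Run object
            trainer.logger.experiment.log({
                'examples': [wandb.Image(x_i, caption=str(y_i)) for x_i, y_i in zip(x[:8], y[:8])]
            })
```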

borisdayma commented 3 years ago

Yes, this would work to access the correct logger directly.

The main issue I see is that several people use the self.log() pattern, which technically is just a shortcut for self.logger.experiment.log() with WandbLogger.

It also has other advantages, as it works when there is no logger attached to the trainer yet (self.log() just becomes a no-op), for example when using the learning rate finder prior to training.

borisdayma commented 3 years ago

Here is a cool notebook which was working a short time ago on master and that I planned to showcase for the next release of pytorch-lightning (I guess v1.4).

borisdayma commented 3 years ago

Just to let you know, for now I replaced pl_module.log('examples', [wandb.Image(…)]) with wandb.log({'examples': [wandb.Image(…)]}) in my colab.

FrancescoSaverioZuppichini commented 2 years ago

Hi guys, is there a way to log a confusion matrix using the built-in .log function? Since you have to pass a reduce function and you can't specify the axis, the resulting logged metric will always be a scalar.
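(Not an answer from the thread, but one possible workaround sketch in the spirit of the earlier suggestions: accumulate predictions and hand the confusion matrix to wandb through the logger's experiment rather than self.log. The module, forward pass, and 10-class setup are illustrative assumptions.)

```python
import torch
import wandb
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):  # illustrative module
    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x).argmax(dim=-1)  # assumes forward() returns class logits
        return {'preds': preds, 'targets': y}

    def validation_epoch_end(self, outputs):
        preds = torch.cat([o['preds'] for o in outputs]).cpu().numpy()
        targets = torch.cat([o['targets'] for o in outputs]).cpu().numpy()
        # bypass self.log and log the confusion matrix via the wandb Run directly
        self.logger.experiment.log({
            'conf_mat': wandb.plot.confusion_matrix(
                preds=preds, y_true=targets,
                class_names=[str(i) for i in range(10)],  # illustrative 10-class setup
            )
        })
```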

tchaton commented 2 years ago

Hey @borisdayma,

Thanks, I bumped this one to P0 as it was definitely a great feature. We need to investigate how to properly support this again.

Best, T.C

borisdayma commented 2 years ago

Cool, let me know how I can help!

tchaton commented 2 years ago

Hey @borisdayma,

After brainstorming with the Lightning Team, we decided to close this issue for the following reasons:

This change has already been in place for 2 major releases, so the log API supporting tensors and metrics is considered a stable API now.

A possible idea would be for the LightningModule to support log_{}, but we don't believe there is a way to support all the underlying data APIs from the loggers in a consistent way, and it would add maintenance cost for the Lightning Team.

Best, T.C

borisdayma commented 2 years ago

Thanks for the feedback. Ping me back if you want to support it at some point.