bw4sz opened this issue 11 months ago
Restoring _current_fx_name might work:
from copy import deepcopy
import lightning as L

def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
    # Save the state that trainer.predict() will overwrite.
    trainer_state = deepcopy(trainer.state)
    current_fx_name = pl_module._current_fx_name
    results = trainer.predict(pl_module, dataloader, return_predictions=True)  # `dataloader` defined elsewhere
    # Restore it so pl_module.log() is allowed again inside this hook.
    trainer.state = trainer_state
    pl_module._current_fx_name = current_fx_name
    pl_module.log("val/ex", 0, prog_bar=True)
Thanks! Can you give any insight into why that works? What is happening here that allows the trainer to be used inside the hook?
The on_validation_epoch_end hook of Callback accepts trainer and pl_module as parameters. In a LightningModule, you can instead access the trainer and _current_fx_name through self.trainer and self._current_fx_name, for example:
# Inside your LightningModule (requires `from copy import deepcopy` at the top of the file):
def on_validation_epoch_end(self):
    if self.trainer.sanity_checking:  # optional: skip during the sanity check
        return
    # Save the state that trainer.predict() will overwrite.
    trainer_state = deepcopy(self.trainer.state)
    current_fx_name = self._current_fx_name
    print("Start predicting!")
    dataloader = self.predict_dataloader()
    self.trainer.predict(self, dataloaders=dataloader)
    # Restore it so self.log() works again.
    self.trainer.state = trainer_state
    self._current_fx_name = current_fx_name
    self.log("metric", 1.0)
As for why that works: running prediction changes the LightningModule's _current_fx_name, which is what causes self.log to stop working afterwards. Resetting _current_fx_name (and the saved trainer state) after prediction restores the validation context, so self.log works again.
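To make that save/restore pattern harder to get wrong, one option is to wrap it in a small context manager. This is not an official Lightning API, just a sketch that relies on the same private _current_fx_name attribute used above:

from contextlib import contextmanager
from copy import deepcopy

import lightning as L


@contextmanager
def preserve_trainer_state(trainer: L.Trainer, pl_module: L.LightningModule):
    """Snapshot trainer.state and pl_module._current_fx_name, and put them
    back afterwards so pl_module.log() keeps working in the calling hook."""
    trainer_state = deepcopy(trainer.state)
    current_fx_name = pl_module._current_fx_name
    try:
        yield
    finally:
        trainer.state = trainer_state
        pl_module._current_fx_name = current_fx_name


# Usage inside LightningModule.on_validation_epoch_end (sketch):
#     with preserve_trainer_state(self.trainer, self):
#         self.trainer.predict(self, dataloaders=self.predict_dataloader())
#     self.log("metric", 1.0)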
Bug description
There has been a lot of discussion around logging, trainer.predict, evaluation hooks, and callbacks. I think I can boil this down to a reproducible example that will be useful for the community.
What has been discussed so far:
https://github.com/Lightning-AI/pytorch-lightning/issues/10365
https://github.com/Lightning-AI/pytorch-lightning/discussions/16258 (where I started the example below)
https://github.com/Lightning-AI/pytorch-lightning/issues/16822
https://github.com/Lightning-AI/pytorch-lightning/issues/7333
From these links, there is no clear guidance on when to use trainer.predict_step() versus trainer.predict, or on why logging can be used with one but not the other. This is flirting with being a bug, but appears to be intended behavior based on the comment below.
We are not inside a predict hook; we are inside an evaluation hook. We did use trainer.predict, with all of its great functionality, to generate a set of predictions. A sketch contrasting the two approaches follows.
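To make the distinction concrete, here is a rough sketch (not taken from the original report) of the two approaches inside on_validation_epoch_end; only the second re-enters the trainer and therefore needs the state restore shown in the comments above. The logged keys are placeholders:

# Inside the LightningModule (sketch):
def on_validation_epoch_end(self):
    # Approach 1: call predict_step() directly. The module never leaves the
    # validation context, so self.log() keeps working. Batches come straight
    # from the dataloader and may need manual device transfer.
    for batch_idx, batch in enumerate(self.predict_dataloader()):
        preds = self.predict_step(batch, batch_idx)
    self.log("val/manual_predict_ok", 1.0)

    # Approach 2: call trainer.predict(). This runs the full predict loop
    # (device transfer, callbacks, etc.) but switches the trainer into the
    # predicting state, after which self.log() fails unless trainer.state and
    # self._current_fx_name are restored as in the workaround above.
    preds = self.trainer.predict(self, dataloaders=self.predict_dataloader())
    self.log("val/after_trainer_predict", 1.0)  # fails without the restore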
Expected behavior
I understand from the above issues, as stated by @carmocca (https://github.com/Lightning-AI/pytorch-lightning/issues/7333#issuecomment-1027107255), that we cannot overwrite the trainer state. Why doesn't this work with a new trainer?
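For illustration only, here is a sketch of what the "new trainer" attempt could look like inside the hook; this is an assumed shape, not the original reproduction code, and the variable names and logged key are placeholders:

import lightning as L


# Inside the LightningModule (sketch):
def on_validation_epoch_end(self):
    # A fresh Trainer, separate from self.trainer, created only for prediction.
    predict_trainer = L.Trainer(logger=False, enable_checkpointing=False)
    predict_trainer.predict(self, dataloaders=self.predict_dataloader())

    # Even with a separate Trainer, predict() still runs against `self`, so
    # module-level state such as _current_fx_name is still touched (and the
    # module may end up attached to the new trainer), which appears to be why
    # self.log() here can still fail.
    self.log("val/after_new_trainer", 1.0)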
What version are you seeing the problem on?
v2.1
How to reproduce the bug
Error messages and logs
Environment
More info
No response