Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Add `predict_epoch_end` hook. #9380

Closed rohitgr7 closed 1 year ago

rohitgr7 commented 3 years ago

🚀 Feature

Motivation

See the discussion in https://github.com/PyTorchLightning/pytorch-lightning/discussions/9379. Also, I remember this is a TODO somewhere in the codebase.

Pitch

The hook will be similar to `{val/test}_epoch_end`, but it will return the outputs. Also, should we update the signature of `on_predict_epoch_end` to not accept the outputs? Hooks don't actually return anything, so even if someone modifies the predictions there, it won't have any effect on the original predictions.
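
A rough sketch of how the proposed hook could look from the user's side (a hypothetical API mirroring `validation_epoch_end`; `LitModel` and its layer are just illustrative):

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self.layer(batch)

    # Hypothetical hook: receives the accumulated per-batch outputs, and
    # (unlike on_predict_epoch_end) its return value would become what
    # trainer.predict() hands back to the caller.
    def predict_epoch_end(self, outputs):
        return torch.cat(outputs)
```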

Alternatives

Can't think of any.

Additional context



ananthsub commented 3 years ago

There are a number of issues with prediction right now that are blocking, at minimum, FB's usage of `Trainer.predict`:

  1. Inconsistent API around outputs, as mentioned here: https://github.com/PyTorchLightning/pytorch-lightning/issues/8479
  2. Predictions are stored & returned by default: https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/trainer.py#L793-L794. API-wise this is inconsistent with `validate` and `test`. More critically for us, this risks OOMs for large-scale prediction unless users are careful to disable this flag.
  3. The Trainer is currently inconsistent in its checks for batch samplers. It first checks whether the dataloader has a batch sampler before applying the wrapper for prediction: https://github.com/PyTorchLightning/pytorch-lightning/blob/8407238d66df14c4476f880a6e6260b4bfa83b40/pytorch_lightning/trainer/data_loading.py#L161-L171

But this later accesses the attribute unconditionally: https://github.com/PyTorchLightning/pytorch-lightning/blob/8407238d66df14c4476f880a6e6260b4bfa83b40/pytorch_lightning/loops/epoch/prediction_epoch_loop.py#L163-L164

A quick fix could be to check whether the dataloader has a batch sampler here: https://github.com/PyTorchLightning/pytorch-lightning/blob/8407238d66df14c4476f880a6e6260b4bfa83b40/pytorch_lightning/loops/epoch/prediction_epoch_loop.py#L163-L164

  4. RE: epoch-end hooks, https://github.com/PyTorchLightning/pytorch-lightning/issues/8731 has more discussion on this. I personally think we should not be adding these hooks, and should instead ask users to either store what is currently returned from `predict_step` inside the LightningModule, or have callbacks do the post-processing in `on_predict_batch_end` (see the sketch after this list). Storing data in the trainer that it doesn't directly use can quickly lead to bugs (if we do some post-processing wrong) or performance slowdowns (if we use more memory than we need to).
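
For reference, a minimal sketch of the callback-based alternative from point 4, built on the existing `BasePredictionWriter` (signatures as in Lightning 1.x; `TorchPredictionWriter` and the output directory are made up for illustration):

```python
import os

import torch
from pytorch_lightning.callbacks import BasePredictionWriter


class TorchPredictionWriter(BasePredictionWriter):
    """Writes each batch of predictions to disk instead of accumulating them in memory."""

    def __init__(self, output_dir, write_interval="batch"):
        super().__init__(write_interval)
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        torch.save(prediction, os.path.join(self.output_dir, f"pred_{dataloader_idx}_{batch_idx}.pt"))


# Usage: predictions stream to disk; nothing is retained by the Trainer.
# trainer = Trainer(callbacks=[TorchPredictionWriter("predictions/")])
# trainer.predict(model, datamodule=dm, return_predictions=False)
```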

cc @tchaton

rohitgr7 commented 3 years ago

Accumulating predictions is pretty much just boilerplate code in the usual cases, and if Lightning can provide it on the fly, then I think `predict_epoch_end` is a useful hook to have. At least there should be some default structure defined that users can rely on without writing the same duplicate logic, since accumulating predictions across different dataloaders isn't trivial for everyone (at least for starters). This is of course optional: if users want, they can write their own logic and disable `return_predictions`.
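
For context, this is the kind of boilerplate being referred to, accumulating per-dataloader predictions by hand (a minimal sketch; `model` and `predict_dataloaders` are illustrative names):

```python
from collections import defaultdict

# Manually accumulate predict_step outputs per dataloader, which is
# exactly the duplicate logic a default predict_epoch_end could absorb.
outputs = defaultdict(list)
for dataloader_idx, loader in enumerate(predict_dataloaders):
    for batch_idx, batch in enumerate(loader):
        outputs[dataloader_idx].append(model.predict_step(batch, batch_idx, dataloader_idx))
```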

rohitgr7 commented 3 years ago

+1 for deprecating `outputs` from `on_predict_epoch_end` for consistency with other hooks if we implement `predict_epoch_end`.

tchaton commented 3 years ago
  1. I believe this is reasonable, as you expect predictions to be returned when calling `predict`, and there is a simple way to opt out. We could add a warning informing users that storing the predictions might cause OOM, and advise the `BasePredictionWriter` alternative.

  2. Good catch!

  3. I believe this is a question of accessibility vs. engineering simplification. I think it is intuitive for the predictions to be saved, but we might want to rethink the API for real-world use cases.

m13uz commented 2 years ago

So what is the status of this feature?

dagap commented 2 years ago

+1 on `predict_epoch_end`, which would allow one to modify the outputs.

CompRhys commented 1 year ago

I would also benefit from this feature!

sofroniewn commented 1 year ago

I would also benefit from this feature, and I'm curious if there is any update on plans here, or whether there are any alternatives or best practices that people have adopted that I could learn from. My use case is the same as in #9379.

I could imagine using `on_predict_epoch_end` to do my post-processing and store the results on my LightningModule, but I don't like that as much: I quite like the syntax of `predictions = trainer.predict(lm, dm)`, and it feels a bit strange to store data on the LightningModule itself.
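
For what it's worth, that workaround would look roughly like this (assuming the Lightning 1.x hook signature, where `on_predict_epoch_end` receives the accumulated results, which I believe are nested per dataloader; `self.processed` is just an illustrative attribute name):

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self.layer(batch)

    def on_predict_epoch_end(self, results):
        # The hook's return value is ignored by the Trainer, so the
        # post-processed output has to be stashed on the module itself.
        self.processed = torch.cat([out for dl_outs in results for out in dl_outs])


# predictions = trainer.predict(lm, dm)  # raw per-batch outputs, as today
# lm.processed                           # post-processed tensor from the hook
```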

Thanks!!