Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Metric on all test data #2809

Closed: celsofranssa closed this issue 3 years ago

celsofranssa commented 4 years ago

Is there a way to handle cases where the metric computed in test_step depends on the entire test set, rather than only on the data in the current batch?

awaelchli commented 4 years ago

Hi, do you mean you want all outputs from your model available in test_epoch_end, so you can compute a metric over all of them? You can do that by returning your predictions in the output dict of the test_step method; test_epoch_end will then receive a list of all of them. Would that work for your use case?
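A framework-free sketch of that pattern, assuming a hypothetical `SketchModule` whose methods mirror the `test_step` / `test_epoch_end` hook names. The toy "model" and the accuracy metric are illustrative only, and the list comprehension at the bottom stands in for the Trainer's test loop.

```python
class SketchModule:
    """Hypothetical stand-in for a LightningModule; only the hook names follow Lightning."""

    def forward(self, x):
        # Toy "model": predicts class 1 for positive inputs, 0 otherwise
        return [1 if v > 0 else 0 for v in x]

    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = self.forward(x)
        # Return raw predictions and targets instead of a per-batch metric value
        return {"preds": preds, "target": y}

    def test_epoch_end(self, outputs):
        # outputs is the list of dicts returned by every test_step call
        preds = [p for out in outputs for p in out["preds"]]
        target = [t for out in outputs for t in out["target"]]
        correct = sum(int(p == t) for p, t in zip(preds, target))
        return {"test_acc": correct / len(target)}


module = SketchModule()
batches = [([0.5, -1.0], [1, 0]), ([2.0, -0.3], [0, 0])]
# Stand-in for the Trainer: call test_step on each batch, then reduce once at the end
outputs = [module.test_step(b, i) for i, b in enumerate(batches)]
print(module.test_epoch_end(outputs))
```

Because the reduction runs once over the concatenated outputs, the same structure works for metrics that cannot be averaged batch by batch.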

celsofranssa commented 4 years ago

I am currently working on an information retrieval model that encodes both the query and the document into 768-dimensional dense vectors. In test_step, I would like to compute metrics such as MRR (Mean Reciprocal Rank) over the entire test set.

Saving the test_step outputs would work. However, since my model outputs embeddings, the best practice may be to save the model and evaluate it in a separate test environment.
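To illustrate why MRR needs the whole test set, here is a minimal sketch that assumes each query `i` has exactly one relevant document, `doc_embs[i]`, and that relevance is scored by a plain dot product. Both the alignment-by-index convention and the function names are assumptions for illustration. Every query is ranked against all document embeddings, so the document vectors for the full test set must be available when the metric is computed.

```python
def dot(a, b):
    # Dot-product similarity between two embedding vectors
    return sum(x * y for x, y in zip(a, b))


def mean_reciprocal_rank(query_embs, doc_embs):
    """MRR assuming query i's only relevant document is doc_embs[i] (hypothetical convention)."""
    total = 0.0
    for i, q in enumerate(query_embs):
        # Score the query against every document in the test set, not just the current batch
        scores = [dot(q, d) for d in doc_embs]
        # Rank = 1 + number of documents scoring strictly higher than the relevant one
        # (ties are resolved optimistically in favor of the relevant document)
        rank = 1 + sum(1 for s in scores if s > scores[i])
        total += 1.0 / rank
    return total / len(query_embs)


# Toy 2-d "embeddings" standing in for the real 768-d vectors
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
docs = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.5]]
mrr = mean_reciprocal_rank(queries, docs)
print(mrr)
```

Here the first two queries rank their document first and the third ranks it third, giving (1 + 1 + 1/3) / 3 = 7/9.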

SkafteNicki commented 3 years ago

@Ceceu Class based metrics have been revamped! Your case is very similar to the ExplainedVariance metric (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/metrics/regression/explained_variance.py), since that metric also needs access to all predictions and targets to compute its final value. Please check it out to see whether it solves your problem.
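The general idea behind such a state-accumulating metric can be sketched without Lightning. The class below is hypothetical and only borrows the `update` / `compute` / `reset` naming; it is not the Lightning `Metric` API. It keeps every prediction and target and computes explained variance, 1 - Var(target - preds) / Var(target), once at the end, which is why the state grows with the size of the test set.

```python
class ExplainedVarianceSketch:
    """Hypothetical accumulate-then-compute metric; not Lightning's Metric class."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.preds = []
        self.target = []

    def update(self, preds, target):
        # Keep every batch: the final value depends on all samples at once
        self.preds.extend(preds)
        self.target.extend(target)

    def compute(self):
        def var(xs):
            # Population variance (divides by n)
            mean = sum(xs) / len(xs)
            return sum((x - mean) ** 2 for x in xs) / len(xs)

        residuals = [t - p for p, t in zip(self.preds, self.target)]
        return 1.0 - var(residuals) / var(self.target)


metric = ExplainedVarianceSketch()
metric.update([1.0, 2.0], [1.0, 2.0])  # batch 1: predictions, targets
metric.update([3.0, 5.0], [3.0, 4.0])  # batch 2
ev = metric.compute()
print(ev)
```

Averaging a per-batch explained variance would generally give a different number, because each batch would use its own mean and variance.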

celsofranssa commented 3 years ago

Thanks @SkafteNicki,

I'll check it out. Currently, I'm saving the predictions to a file (using EvalResult) and computing the metric after the train step.

celsofranssa commented 3 years ago

@SkafteNicki,

Where is the state variable stored? Couldn't that lead to running out of memory?

SkafteNicki commented 3 years ago

It is stored on the same device as your model. You are completely right that this can lead to out-of-memory errors; however, to my understanding, it is impossible to compute such metrics without access to all predictions and targets. That is the very reason for this warning: https://github.com/PyTorchLightning/pytorch-lightning/blob/89e8796e2a14429541f923b2ecf8fd7079c32a65/pytorch_lightning/metrics/regression/explained_variance.py#L94-L98
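A quick back-of-the-envelope estimate shows how fast this state grows for the 768-dimensional embeddings mentioned earlier. This is a sketch assuming float32 values (4 bytes each) and counting only the stored vectors, not any framework overhead.

```python
def embedding_bytes(num_vectors, dim=768, bytes_per_value=4):
    # float32 = 4 bytes per value; use 2 for float16 or 8 for float64
    return num_vectors * dim * bytes_per_value


for n in (10_000, 1_000_000):
    print(f"{n:>9} vectors -> {embedding_bytes(n) / 1e9:.3f} GB")
```

One million float32 vectors of size 768 take 3,072,000,000 bytes, roughly 3.07 GB, before counting queries, targets, or any intermediate score matrix.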

celsofranssa commented 3 years ago

Would it be possible to compute a metric from predictions saved to disk, using one of the methods below?

https://github.com/PyTorchLightning/pytorch-lightning/blob/471ca375babad9093abf60683a8d0647ac33d4a8/pytorch_lightning/core/lightning.py#L347-L352

If so, one could use test_step to save predictions_dict and test_epoch_end to compute the metric. What do you think?
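A minimal sketch of that idea, assuming a hypothetical `DiskBackedTest` class whose methods mirror the Lightning hook names. Each `test_step` pickles its `predictions_dict` to its own file and returns only the file path. `test_epoch_end` then reloads everything and applies a metric function. The file layout, the pickle format, and the `metric_fn` parameter are illustrative assumptions. The final reduction still needs the data in host memory, but it no longer sits on the model's device for the whole epoch.

```python
import os
import pickle
import tempfile


def accuracy(preds, target):
    return sum(int(p == t) for p, t in zip(preds, target)) / len(target)


class DiskBackedTest:
    """Hypothetical sketch: spill each batch's predictions to disk, reduce at epoch end."""

    def __init__(self, out_dir, metric_fn=accuracy):
        self.out_dir = out_dir
        self.metric_fn = metric_fn

    def test_step(self, batch, batch_idx):
        preds, target = batch  # stand-in: a real module would run the model here
        predictions_dict = {"preds": preds, "target": target}
        path = os.path.join(self.out_dir, f"batch_{batch_idx:06d}.pkl")
        with open(path, "wb") as f:
            pickle.dump(predictions_dict, f)
        return path  # only the small file path is kept in memory

    def test_epoch_end(self, outputs):
        preds, target = [], []
        for path in sorted(outputs):  # zero-padded names keep batch order
            with open(path, "rb") as f:
                d = pickle.load(f)
            preds.extend(d["preds"])
            target.extend(d["target"])
        return {"test_metric": self.metric_fn(preds, target)}


with tempfile.TemporaryDirectory() as tmp:
    module = DiskBackedTest(tmp)
    batches = [([1, 0], [1, 0]), ([1, 1], [0, 1])]
    outputs = [module.test_step(b, i) for i, b in enumerate(batches)]
    result = module.test_epoch_end(outputs)
print(result)
```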

SkafteNicki commented 3 years ago

I definitely think that would work. A good future enhancement would be to give metrics a generic way of saving their state to disk and loading it automatically during the .compute() step.
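One way such an enhancement could look is sketched below. Everything here is hypothetical: `DiskStateMetric`, its `compute_from` hook, and the one-file-per-`update` layout are illustrative and are not part of Lightning's Metric API. Each `update` writes its batch to disk, and `compute` reloads all batches before delegating to the subclass. Median absolute error is used as the example because a median needs every sample at once.

```python
import os
import pickle
import statistics
import tempfile


class DiskStateMetric:
    """Hypothetical base class that keeps metric state on disk instead of on the device."""

    def __init__(self, state_dir):
        self.state_dir = state_dir
        self._num_batches = 0

    def _path(self, i):
        return os.path.join(self.state_dir, f"state_{i:06d}.pkl")

    def update(self, preds, target):
        # Write this batch to its own file instead of growing an in-memory buffer
        with open(self._path(self._num_batches), "wb") as f:
            pickle.dump((list(preds), list(target)), f)
        self._num_batches += 1

    def compute(self):
        # Load every saved batch, then hand the full data to the subclass
        preds, target = [], []
        for i in range(self._num_batches):
            with open(self._path(i), "rb") as f:
                p, t = pickle.load(f)
            preds.extend(p)
            target.extend(t)
        return self.compute_from(preds, target)

    def reset(self):
        for i in range(self._num_batches):
            os.remove(self._path(i))
        self._num_batches = 0

    def compute_from(self, preds, target):
        raise NotImplementedError


class MedianAbsoluteError(DiskStateMetric):
    def compute_from(self, preds, target):
        # The median cannot be averaged across batches, so it needs all samples
        return statistics.median(abs(p - t) for p, t in zip(preds, target))


with tempfile.TemporaryDirectory() as tmp:
    metric = MedianAbsoluteError(tmp)
    metric.update([1.0, 2.0], [1.5, 2.0])
    metric.update([3.0, 10.0], [3.0, 4.0])
    value = metric.compute()
    metric.reset()
    leftover = os.listdir(tmp)
print(value)
```

The absolute errors are 0.5, 0, 0 and 6, so the median is 0.25.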