pytorch / torcheval

A library that contains a rich collection of performant PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training and tools for PyTorch model evaluations.
https://pytorch.org/torcheval

RetrievalRecall, RetrievalPrecision require different, 1D input than MulticlassRecall, MulticlassPrecision which accept batch input #188

Open jaanli opened 8 months ago

jaanli commented 8 months ago

🐛 Describe the bug

RetrievalRecall and RetrievalPrecision expect flattened 1D inputs, while MulticlassRecall and MulticlassPrecision accept batched inputs. This difference makes it difficult to compute standard metrics such as Precision@k or Recall@k for multiclass classification problems.

Would it be possible to have them accept the same input shapes, e.g. inputs of shape (batch_size, num_classes) and targets of shape (batch_size, num_classes)?

Example code below:

To install: pip install --pre torcheval-nightly (I am using version 0.0.7).

import torch
from torch.nn import functional as F
from torcheval.metrics import RetrievalRecall

batch_size = 10
num_classes = 20
# generate random predictions
preds = torch.rand(batch_size, num_classes)
# generate random targets
targets = torch.randint(0, num_classes, (batch_size,))

recall = RetrievalRecall(num_queries=batch_size, k=5)

# first make the targets one hot (RetrievalRecall does not accept num_classes arguments, requires binary targets)
targets_one_hot = F.one_hot(targets.type(torch.long), num_classes)
targets_one_hot.shape

# indexes associate each prediction with a target
indexes = torch.arange(batch_size).repeat(num_classes, 1).T

recall.update(preds.ravel(), targets_one_hot.ravel(), indexes=indexes.ravel())

recall.compute().mean() # -> 0.1

from torcheval.metrics import MulticlassRecall, MulticlassPrecision

recall = MulticlassRecall(num_classes=num_classes)
precision = MulticlassPrecision(num_classes=num_classes)
recall.update(preds, targets)
precision.update(preds, targets)
recall.compute(), precision.compute() # -> 0.1, 0.1
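For comparison, here is a minimal sketch of the batched behavior requested above, written in plain PyTorch rather than torcheval. The helper name `topk_recall_precision` is hypothetical and not part of any library. It assumes single-label targets, so each sample has exactly one relevant class and Recall@k reduces to "is the target among the top k scores".

```python
import torch


def topk_recall_precision(preds: torch.Tensor, targets: torch.Tensor, k: int):
    """Per-sample Recall@k and Precision@k for single-label multiclass targets.

    preds:   (batch_size, num_classes) scores
    targets: (batch_size,) integer class ids
    Hypothetical helper for illustration; not a torcheval API.
    """
    # Indices of the k highest-scoring classes for every sample: (batch_size, k)
    topk_idx = preds.topk(k, dim=1).indices
    # 1.0 where the target class appears among the top k, else 0.0
    hits = (topk_idx == targets.unsqueeze(1)).any(dim=1).float()
    recall_at_k = hits          # one relevant class per sample
    precision_at_k = hits / k   # at most one of the k retrieved classes is relevant
    return recall_at_k, precision_at_k


preds = torch.tensor([[0.1, 0.7, 0.2, 0.0],
                      [0.4, 0.3, 0.2, 0.1],
                      [0.0, 0.1, 0.2, 0.7]])
targets = torch.tensor([2, 3, 3])

recall_at_2, precision_at_2 = topk_recall_precision(preds, targets, k=2)
print(recall_at_2)     # tensor([1., 0., 1.])
print(precision_at_2)  # tensor([0.5000, 0.0000, 0.5000])
```

Averaging `recall_at_2` over the batch gives the mean Recall@2. This is only a reference for small cases and does not handle multi-label targets or accumulation across batches.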

Current workaround:

import torch
from torch.nn import functional as F
from torcheval.metrics import RetrievalRecall

class MulticlassRetrievalRecall(RetrievalRecall):
    def __init__(self, batch_size, num_classes, **kwargs):
        super().__init__(num_queries=batch_size, **kwargs)
        self.num_classes = num_classes

    def update(self, input, target):
        target_one_hot = F.one_hot(target.type(torch.long), self.num_classes)
        indexes = torch.arange(len(input)).repeat(self.num_classes, 1).T
        super().update(input.ravel(), target_one_hot.ravel(), indexes=indexes.ravel())

Usage:

recall_multi = MulticlassRetrievalRecall(batch_size, num_classes, k=5)
recall_multi.update(preds, targets)
recall_multi.compute().mean() # -> 0.1
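To make the flattening in the workaround easier to follow, here is a small sketch of how the three 1D tensors line up for a tiny batch. It uses only plain PyTorch. The `repeat_interleave` variant is an alternative way to build the same index tensor, offered as a suggestion rather than something torcheval requires.

```python
import torch
from torch.nn import functional as F

batch_size, num_classes = 2, 3
preds = torch.tensor([[0.2, 0.5, 0.3],
                      [0.6, 0.1, 0.3]])
targets = torch.tensor([1, 2])

# Row-major flattening: all scores of sample 0 first, then sample 1.
flat_preds = preds.ravel()

# One-hot targets flattened the same way, so each score is paired with a 0/1 label.
flat_targets = F.one_hot(targets, num_classes).ravel()

# Query index for every flattened score, built as in the workaround above.
indexes = torch.arange(batch_size).repeat(num_classes, 1).T.ravel()

# Equivalent construction without the repeat-and-transpose step.
indexes_alt = torch.arange(batch_size).repeat_interleave(num_classes)

print(flat_targets)  # tensor([0, 1, 0, 0, 0, 1])
print(indexes)       # tensor([0, 0, 0, 1, 1, 1])
```

Each position i in the flattened tensors then describes one (sample, class) pair: `indexes[i]` is the sample, `flat_preds[i]` its score for that class, and `flat_targets[i]` whether that class is the true label.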

Open to any tips on how best to do this! Thanks for this helpful canonical library :)

Versions

python collect_env.py

Collecting environment information...
PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.6.2 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.22.2
Libc version: N/A

Python version: 3.11.6 (main, Nov  2 2023, 04:39:43) [Clang 14.0.3 (clang-1403.0.22.14.1)] (64-bit runtime)
Python platform: macOS-13.6.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Max

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] torch==2.1.1
[pip3] torchaudio==2.1.1
[pip3] torchdata==0.7.1
[pip3] torcheval==0.0.7
[pip3] torcheval-nightly==2023.12.21
[pip3] torchtext==0.16.1
[pip3] torchvision==0.16.1
[conda] numpy                     1.24.3          py310hb93e574_0  
[conda] numpy-base                1.24.3          py310haf87e8b_0  
[conda] torch                     2.0.1                    pypi_0    pypi

jaanli commented 8 months ago

cc @jsseely for visibility

bobakfb commented 8 months ago

@galrotem @JKSenthil Any chance one of you could look into this and the other issue posted by @jaanli?