pytorch / torcheval

A library that contains a rich collection of performant PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training and tools for PyTorch model evaluations.

RetrievalRecall, RetrievalPrecision require different, 1D input than MulticlassRecall, MulticlassPrecision which accept batch input #188

Open jaanli opened 8 months ago

jaanli commented 8 months ago

🐛 Describe the bug

RetrievalRecall and RetrievalPrecision expect flattened 1D input, while MulticlassRecall and MulticlassPrecision accept batched input. This difference makes it difficult to compute standard metrics such as Precision@k or Recall@k for multiclass classification problems.

Would it be possible to have them accept the same input shapes, e.g. inputs of shape (batch_size, num_classes) and targets of shape (batch_size, num_classes)?

Example code below:

To install: pip install --pre torcheval-nightly (I am using version '0.0.7').

import torch
from torch.nn import functional as F
from torcheval.metrics import RetrievalRecall

batch_size = 10
num_classes = 20
# generate random predictions
preds = torch.rand(batch_size, num_classes)
# generate random targets
targets = torch.randint(0, num_classes, (batch_size,))

recall = RetrievalRecall(num_queries=batch_size, k=5)

# first make the targets one hot (RetrievalRecall does not accept num_classes arguments, requires binary targets)
targets_one_hot = F.one_hot(targets.type(torch.long), num_classes)

# indexes associate each prediction with a target
indexes = torch.arange(batch_size).repeat(num_classes, 1).T

recall.update(preds.ravel(), targets_one_hot.ravel(), indexes=indexes.ravel())

recall.compute().mean() # -> 0.1

from torcheval.metrics import MulticlassRecall, MulticlassPrecision

recall = MulticlassRecall(num_classes=num_classes)
precision = MulticlassPrecision(num_classes=num_classes)
recall.update(preds, targets)
precision.update(preds, targets)
recall.compute(), precision.compute() # -> 0.1, 0.1
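
For readers following the flattening step in the retrieval example above, here is a minimal, dependency-free sketch of the layout it produces: one entry per (sample, class) pair, a binary target marking the true class, and an index tying each entry back to its sample. The helper name `flatten_for_retrieval` is hypothetical and is not part of torcheval; it only mirrors what the one-hot encoding and `ravel()` calls appear to do.

```python
# Hypothetical helper (not part of torcheval): builds the flattened inputs by hand.
def flatten_for_retrieval(scores, targets):
    """scores: one row of per-class scores per sample.
    targets: one integer class label per sample."""
    flat_scores, flat_targets, flat_indexes = [], [], []
    for sample_id, (row, label) in enumerate(zip(scores, targets)):
        for class_id, score in enumerate(row):
            flat_scores.append(score)
            # Binary relevance: 1 only for the sample's true class.
            flat_targets.append(1 if class_id == label else 0)
            # Every class score of a sample shares that sample's query index.
            flat_indexes.append(sample_id)
    return flat_scores, flat_targets, flat_indexes


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
targets = [1, 0]
flat_scores, flat_targets, flat_indexes = flatten_for_retrieval(scores, targets)
print(flat_targets)  # [0, 1, 0, 1, 0, 0]
print(flat_indexes)  # [0, 0, 0, 1, 1, 1]
```

Each sample therefore becomes one "query" with num_classes candidate "documents", which is why num_queries is set to batch_size.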

Current workaround:

import torch
from torch.nn import functional as F
from torcheval.metrics import RetrievalRecall

class MulticlassRetrievalRecall(RetrievalRecall):
    def __init__(self, batch_size, num_classes, **kwargs):
        super().__init__(num_queries=batch_size, **kwargs)
        self.num_classes = num_classes

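    def update(self, input, target):
        # RetrievalRecall expects binary targets, so one-hot encode the class labels.
        target_one_hot = F.one_hot(target.type(torch.long), self.num_classes)
        # Query index for every (sample, class) pair: row i is filled with i.
        indexes = torch.arange(len(input), device=input.device).repeat(self.num_classes, 1).T
        super().update(input.ravel(), target_one_hot.ravel(), indexes=indexes.ravel())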


recall_multi = MulticlassRetrievalRecall(batch_size, num_classes, k=5)
recall_multi.update(preds, targets)
recall_multi.compute().mean() # -> 0.1
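
As a way to cross-check a workaround like this on small inputs, here is a dependency-free sketch of Recall@k for single-label multiclass data: a sample counts as a hit when its true class is among the k highest-scoring classes. The function name `recall_at_k` is hypothetical, and this is a reference definition under that assumption, not torcheval's implementation.

```python
# Hypothetical reference function (not part of torcheval).
def recall_at_k(scores, targets, k):
    """Fraction of samples whose true class is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, targets):
        # Class ids sorted by descending score; keep the first k.
        top_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += label in top_k
    return hits / len(targets)


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
targets = [1, 2, 1]
print(recall_at_k(scores, targets, k=1))  # 1 of 3 samples hit
print(recall_at_k(scores, targets, k=2))  # 2 of 3 samples hit
```

With exactly one relevant class per sample, the corresponding mean Precision@k is this value divided by k, since each sample contributes either 0 or 1 hit among its k retrieved classes.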

Open to any tips on how best to do this! Thanks for this helpful canonical library :)



Collecting environment information...
PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.6.2 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.22.2
Libc version: N/A

Python version: 3.11.6 (main, Nov  2 2023, 04:39:43) [Clang 14.0.3 (clang-1403.] (64-bit runtime)
Python platform: macOS-13.6.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Apple M1 Max

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] torch==2.1.1
[pip3] torchaudio==2.1.1
[pip3] torchdata==0.7.1
[pip3] torcheval==0.0.7
[pip3] torcheval-nightly==2023.12.21
[pip3] torchtext==0.16.1
[pip3] torchvision==0.16.1
[conda] numpy                     1.24.3          py310hb93e574_0  
[conda] numpy-base                1.24.3          py310haf87e8b_0  
[conda] torch                     2.0.1                    pypi_0    pypi

jaanli commented 8 months ago

cc @jsseely for visibility

bobakfb commented 8 months ago

@galrotem @JKSenthil Any chance one of you could look into this and the other issue posted by @jaanli?