pytorch / torcheval

A library that contains a rich collection of performant PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training and tools for PyTorch model evaluations.
https://pytorch.org/torcheval

The score computed by `multiclass_f1_score` for binary classification is wrong. It is not f1 score but accuracy. #191

Closed LittletreeZou closed 3 months ago

LittletreeZou commented 7 months ago

🐛 Describe the bug

The score computed by `multiclass_f1_score` for binary classification is wrong. It is not the F1 score but accuracy, as shown in the following code:

import torch
from torcheval.metrics.functional import multiclass_f1_score, binary_f1_score

actual = torch.repeat_interleave(torch.tensor([1, 0]), repeats=torch.tensor([100, 100]))
pred = torch.repeat_interleave(torch.tensor([1, 0, 1, 0]), repeats=torch.tensor([55, 45, 34, 66]))

multiclass_f1_score(pred, actual, num_classes=2)
# tensor(0.6050)

(actual == pred).sum()/200
# tensor(0.6050)

binary_f1_score(pred, actual)
# tensor(0.5820)

Versions

torcheval 0.0.7

MattBrth commented 3 months ago

I have the same issue! Please fix this

bobakfb commented 3 months ago

Hi @LittletreeZou and @MattBrth! Thank you both for contributing to torcheval!

I think there is a misunderstanding of multiclass metrics going on here! In most cases Binary-X and Multiclass-X with 2 classes are not the same thing. In the binary setting there is only one class, and each example is either in or not in that class. In the multiclass setting, there are two classes and an F1 score is computed for each: first treating 0 as the positive label and counting true positives, false positives, etc. against actual = 0, then treating 1 as the positive label and doing the same.

In other words, the 0.6050 you're getting is not a single binary F1 score but the micro average over both classes (the default is average="micro"), which pools the per-class true positives, false positives, and false negatives before computing F1. For single-label multiclass input, that pooled score is exactly accuracy.
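
To make that concrete, here is a quick sketch in plain PyTorch (nothing torcheval-specific; the `counts` helper is just for illustration) that reproduces all three numbers from the raw TP/FP/FN counts:

import torch

actual = torch.repeat_interleave(torch.tensor([1, 0]), repeats=torch.tensor([100, 100]))
pred = torch.repeat_interleave(torch.tensor([1, 0, 1, 0]), repeats=torch.tensor([55, 45, 34, 66]))

def counts(positive_class):
    # Treat `positive_class` as the positive label and count TP / FP / FN against it.
    tp = ((pred == positive_class) & (actual == positive_class)).sum().item()
    fp = ((pred == positive_class) & (actual != positive_class)).sum().item()
    fn = ((pred != positive_class) & (actual == positive_class)).sum().item()
    return tp, fp, fn

tp0, fp0, fn0 = counts(0)
tp1, fp1, fn1 = counts(1)

# Per-class F1: 2*TP / (2*TP + FP + FN)
2 * tp0 / (2 * tp0 + fp0 + fn0)   # 0.6256 for class 0
2 * tp1 / (2 * tp1 + fp1 + fn1)   # 0.5820 for class 1

# Micro averaging pools the counts over both classes before forming F1,
# which for single-label multiclass predictions is exactly accuracy:
2 * (tp0 + tp1) / (2 * (tp0 + tp1) + fp0 + fp1 + fn0 + fn1)   # 0.6050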

Note: if you pass average=None to multiclass_f1_score, you will recover the per-class scores, and the one for class 1 (0.5820) is exactly what binary_f1_score returned above:

>>> from torcheval.metrics import MulticlassF1Score
>>> metric = MulticlassF1Score(num_classes=2, average=None)
>>> metric.update(pred, actual)
>>> print(metric.compute())
tensor([0.6256, 0.5820])

Also note that this is standard behavior; it matches what sklearn does, for example:

>>> import torch
>>> from sklearn.metrics import f1_score
>>> actual = torch.repeat_interleave(torch.tensor([1, 0]), repeats=torch.tensor([100, 100]))
>>> pred = torch.repeat_interleave(torch.tensor([1, 0, 1, 0]), repeats=torch.tensor([55, 45, 34, 66]))
>>> target_np = actual.numpy()
>>> input_np = pred.numpy()
>>> f1_score(target_np, input_np, average='micro')
0.605
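
And if you pass average=None to sklearn as well (reusing target_np and input_np from above), you should get the per-class scores back, with the class-1 entry again matching binary_f1_score:

>>> f1_score(target_np, input_np, average=None)
array([0.62559242, 0.58201058])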