Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

`BinaryPrecisionRecallCurve` computes wrong value if used with logits, even though the docstring says this is supported #2329

Open NoahAtKintsugi opened 9 months ago

NoahAtKintsugi commented 9 months ago

🐛 Bug

The object-oriented `BinaryPrecisionRecallCurve` can compute substantially incorrect values if logits are passed as predictions instead of probabilities, even though the docstring says this is supported. The underlying functional version `binary_precision_recall_curve` seems to work correctly in both cases. Both versions attempt to convert logits to probabilities by passing them through a sigmoid if any values fall outside the range [0, 1]. In the object-oriented case this condition is incorrectly checked independently for each batch, rather than for the metric as a whole. Consequently some batches may have sigmoid applied to their scores while others do not, resulting in an incorrect curve for the dataset as a whole.
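
For comparison, the functional API formats the whole input in one call, so the sigmoid auto-detection is applied consistently. A minimal sketch with toy data:

```python
import torch
from torchmetrics.functional.classification import binary_precision_recall_curve

score = torch.tensor([-1.0, 0.0, 1.0])  # logits: -1.0 falls outside [0, 1]
label = torch.tensor([1, 0, 1])

# The "any value outside [0, 1]" check sees the whole tensor at once, so
# sigmoid is applied to every score (or to none of them), never to a subset.
precision, recall, thresholds = binary_precision_recall_curve(score, label)
```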

To Reproduce

import torch
import torchmetrics

b1 = torchmetrics.AUROC(task="binary")
b2 = torchmetrics.AUROC(task="binary")

score = torch.tensor([-1.0, 0.0, 1.0])
label = torch.tensor([1.0, 0.0, 1.0])

b1.update(score, label)  # pass all three score/label pairs in as a single batch

for s, l in zip(score, label):
    b2.update(torch.tensor([s]), torch.tensor([l]))  # pass score/label pairs in one at a time

assert b1.compute().item() == 0.5  # single batch: -1.0 is present, so sigmoid is applied to all scores
assert b2.compute().item() == 1.0  # per-batch updates: sigmoid is applied only to the -1.0 batch

Expected behavior

I would expect both AUCs to equal 0.5, as computed with scikit-learn using sklearn.metrics.roc_auc_score(label, score).
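
For reference, the scikit-learn computation on the same data:

```python
import torch
from sklearn.metrics import roc_auc_score

score = torch.tensor([-1.0, 0.0, 1.0])
label = torch.tensor([1.0, 0.0, 1.0])

# roc_auc_score ranks the raw scores, and sigmoid is monotonic, so logits and
# probabilities give the same result here.
print(roc_auc_score(label.numpy(), score.numpy()))  # 0.5
```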

Environment

Additional context

The bug is on line https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/classification/precision_recall_curve.py#L165 -- as it is currently written, `_binary_precision_recall_curve_format` should only be called on the full dataset, not on individual batches. Otherwise the behavior is wrong if some batches have all scores in the range [0, 1] while other batches do not.
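
Roughly, the per-batch formatting behaves like the simplified sketch below (`maybe_sigmoid` is a stand-in for the real formatting helper, not the actual torchmetrics code):

```python
import torch

def maybe_sigmoid(preds: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in for the check in _binary_precision_recall_curve_format:
    # apply sigmoid only if some value falls outside [0, 1].
    if ((preds < 0) | (preds > 1)).any():
        return preds.sigmoid()
    return preds

batches = [torch.tensor([-1.0]), torch.tensor([0.0]), torch.tensor([1.0])]

# Formatting each batch separately (what update() effectively does today):
# only the first batch gets sigmoid, mixing scales across batches.
per_batch = torch.cat([maybe_sigmoid(b) for b in batches])  # tensor([0.2689, 0.0000, 1.0000])

# Formatting the full dataset once (what should happen): one consistent scale.
full = maybe_sigmoid(torch.cat(batches))  # tensor([0.2689, 0.5000, 0.7311])
```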

Some possible solutions are:

1. Update the docs to disallow logits in the object-oriented interface, since the behavior is correct for probabilities.
2. Don't try to automatically infer whether to apply sigmoid. This would be my choice, but it would be a breaking change.
3. Refactor `_binary_precision_recall_curve_format` and accept that if any values are found which require sigmoid, then all values from past batches need to have sigmoid applied as well. This would be tricky in the case where thresholds are specified, because the scores are not kept around.
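
In the meantime, a possible workaround is to convert logits to probabilities yourself before calling update, so every batch is already on the probability scale and the auto-detection never fires inconsistently:

```python
import torch
import torchmetrics

metric = torchmetrics.AUROC(task="binary")

score = torch.tensor([-1.0, 0.0, 1.0])  # logits
label = torch.tensor([1, 0, 1])

for s, l in zip(score, label):
    # sigmoid up front: every update sees values in [0, 1]
    metric.update(s.sigmoid().unsqueeze(0), l.unsqueeze(0))

print(metric.compute().item())  # 0.5, matching scikit-learn
```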

github-actions[bot] commented 9 months ago

Hi! Thanks for your contribution, great first issue!