Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

`top_k` for `MulticlassRecall` is not working as expected #2166

Open c23996 opened 1 year ago

c23996 commented 1 year ago

🐛 Bug

The `MulticlassRecall` metric with `top_k > 1` and `average="macro"` is not behaving as anticipated. As `top_k` increases, the result should be non-decreasing; however, in some cases it actually decreases.

To Reproduce

Code sample:

```py
import torch
from torchmetrics.classification import MulticlassRecall

num_classes = 200
preds = torch.randn(5, num_classes).softmax(dim=-1)
target = torch.randint(num_classes, (5,))

recall_val_top1 = MulticlassRecall(num_classes=num_classes, top_k=1, average="macro")
recall_val_top5 = MulticlassRecall(num_classes=num_classes, top_k=5, average="macro")
recall_val_top10 = MulticlassRecall(num_classes=num_classes, top_k=10, average="macro")
recall_val_top100 = MulticlassRecall(num_classes=num_classes, top_k=100, average="macro")

print(recall_val_top1(preds, target), recall_val_top5(preds, target),
      recall_val_top10(preds, target), recall_val_top100(preds, target))
```

It returns `tensor(0.) tensor(0.0357) tensor(0.0213) tensor(0.0154)`.

Expected behavior

The result is expected to rise (or at least not fall) as k grows.
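
For reference, here is a minimal sketch of one reasonable definition of macro recall@k, computed by hand; the helper `manual_macro_recall_at_k` is hypothetical and only illustrates why the metric should be monotone in k:

```py
import torch

def manual_macro_recall_at_k(preds, target, num_classes, k):
    # A sample is a hit for its true class if that class appears
    # among the k highest-scoring predictions.
    topk = preds.topk(k, dim=-1).indices          # (N, k)
    hits = (topk == target.unsqueeze(-1)).any(-1)  # (N,)
    recalls = []
    for c in range(num_classes):
        mask = target == c
        if mask.any():  # average only over classes present in `target`
            recalls.append(hits[mask].float().mean())
    return torch.stack(recalls).mean()
```

Under this definition, growing k can only add hits, so each per-class recall is non-decreasing, and so is their average.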

Environment

Additional context

I checked the function `_adjust_weights_safe_divide`, where the recall is reduced for `_precision_recall_reduce`, and I am unsure about this snippet:

```py
weights = torch.ones_like(score)
if not multilabel:
    # zero out classes with no tp, fp, or fn before macro-averaging
    weights[tp + fp + fn == 0] = 0.0
return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
```

In a multiclass scenario, when computing top-k results, the number of false positives (fp) tends to grow with k. This keeps more entries of `weights` nonzero, which inflates `weights.sum(-1, keepdim=True)` and consequently lowers the final recall@k. Also, when computing macro-averaged recall, shouldn't the condition be `weights[tp + fn == 0] = 0.0`?
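
A minimal sketch of that suggested change, assuming `score`, `tp`, `fp`, `fn`, `multilabel`, and `_safe_divide` keep their meaning from `_adjust_weights_safe_divide`; this illustrates the proposal and is not a tested patch:

```py
weights = torch.ones_like(score)
if not multilabel:
    # For recall, a class is undefined only when it has no positive
    # samples (tp + fn == 0); fp should not re-enable its weight,
    # otherwise extra false positives at larger k inflate the
    # denominator weights.sum(-1, keepdim=True).
    weights[tp + fn == 0] = 0.0
return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
```

Since `_adjust_weights_safe_divide` appears to be shared with precision, such a change would presumably need to be conditioned on which metric is being reduced.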

github-actions[bot] commented 1 year ago

Hi! Thanks for your contribution, great first issue!

iuhgnor commented 12 months ago

`top_k` for `MulticlassAccuracy` doesn't work as expected either.
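
A minimal repro sketch in the same style as the original report (exact values will vary with the random draw):

```py
import torch
from torchmetrics.classification import MulticlassAccuracy

num_classes = 200
preds = torch.randn(5, num_classes).softmax(dim=-1)
target = torch.randint(num_classes, (5,))

# macro accuracy@k should be non-decreasing in k, but can dip here too
for k in (1, 5, 10, 100):
    acc = MulticlassAccuracy(num_classes=num_classes, top_k=k, average="macro")
    print(k, acc(preds, target))
```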