🐛 Bug
`MulticlassRecall` with `top_k > 1` and `average="macro"` does not behave as expected. As `top_k` increases, the result should increase (or at least not decrease), but in some cases it decreases instead.

To Reproduce

Code sample:
```py
import torch
from torchmetrics.classification import MulticlassRecall

num_classes = 200
# Random scores for 5 samples, so exact values differ from run to run
preds = torch.randn(5, num_classes).softmax(dim=-1)
target = torch.randint(num_classes, (5,))

# Macro-averaged recall@k for k = 1, 5, 10, 100
recall_val_top1 = MulticlassRecall(num_classes=num_classes, top_k=1, average="macro")
recall_val_top5 = MulticlassRecall(num_classes=num_classes, top_k=5, average="macro")
recall_val_top10 = MulticlassRecall(num_classes=num_classes, top_k=10, average="macro")
recall_val_top100 = MulticlassRecall(num_classes=num_classes, top_k=100, average="macro")

print(
    recall_val_top1(preds, target),
    recall_val_top5(preds, target),
    recall_val_top10(preds, target),
    recall_val_top100(preds, target),
)
```

It returns `tensor(0.) tensor(0.0357) tensor(0.0213) tensor(0.0154)`, so the macro recall drops as `top_k` grows from 5 to 100.

Expected behavior
The result is expected to increase (or at least not decrease) as k grows.
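The expectation follows from how top-k works: a sample's top-k set is contained in its top-(k+1) set, so for each class the number of hits `tp` can only grow with `k`, while `tp + fn` (the number of samples of that class in `target`) stays fixed. The sketch below checks this in plain Python. It is an illustration, not torchmetrics code; the helpers `topk` and `per_class_recall` are hypothetical, and ties are broken by class index.

```python
import random


def topk(scores, k):
    # Class indices of the k highest scores; sorted() is stable, so ties keep index order
    return sorted(range(len(scores)), key=lambda c: -scores[c])[:k]


def per_class_recall(preds, target, num_classes, k):
    # recall_c = tp_c / (tp_c + fn_c); tp_c + fn_c is the number of samples of
    # class c, which does not depend on k. Classes absent from target give None.
    tp = [0] * num_classes
    support = [0] * num_classes
    for scores, t in zip(preds, target):
        support[t] += 1
        if t in topk(scores, k):
            tp[t] += 1
    return [tp[c] / support[c] if support[c] else None for c in range(num_classes)]


random.seed(0)
num_classes = 20
preds = [[random.random() for _ in range(num_classes)] for _ in range(50)]
target = [random.randrange(num_classes) for _ in range(50)]

previous = None
for k in (1, 5, 10, 20):
    recalls = per_class_recall(preds, target, num_classes, k)
    if previous is not None:
        # No class that occurs in target loses recall when k grows
        assert all(r is None or r >= p for r, p in zip(recalls, previous))
    present = [r for r in recalls if r is not None]
    print(f"k={k}", f"mean recall over classes in target: {sum(present) / len(present):.4f}")
    previous = recalls
```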
Environment
pip
Additional context
I checked the function `_adjust_weights_safe_divide`, which `_precision_recall_reduce` uses to compute recall, and am unsure about this snippet:
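In the multiclass case, the number of false positives (`fp`) tends to increase with `k` when computing top-k results: a class that appears in a sample's top-k predictions without being that sample's target counts as a false positive. More classes therefore end up with a nonzero weight, which increases `weights.sum(-1, keepdim=True)` and consequently reduces the final macro recall@k, because classes that never occur in `target` (`tp + fn == 0`) contribute a recall of 0. Also, when calculating macro-averaged recall, should it be `weights[tp + fn == 0] = 0.0`?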
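To illustrate the question, here is a plain-Python sketch of the weighting described above, not torchmetrics' actual implementation. It assumes the current mask zeroes classes with `tp + fp + fn == 0` (inferred from the question, not confirmed against the source) and compares it with the suggested `tp + fn == 0` mask. `stat_scores` and `macro_recall` are hypothetical helpers; top-k predictions are counted as one-hot hits, so every predicted class other than the target is a false positive.

```python
import random


def topk(scores, k):
    # Class indices of the k highest scores; sorted() is stable, so ties keep index order
    return sorted(range(len(scores)), key=lambda c: -scores[c])[:k]


def stat_scores(preds, target, num_classes, k):
    # Per-class counts under top-k: the target class is a tp if it is in the
    # top-k and an fn otherwise; every other top-k class is an fp
    tp, fp, fn = [0] * num_classes, [0] * num_classes, [0] * num_classes
    for scores, t in zip(preds, target):
        predicted = topk(scores, k)
        for c in predicted:
            if c == t:
                tp[c] += 1
            else:
                fp[c] += 1
        if t not in predicted:
            fn[t] += 1
    return tp, fp, fn


def macro_recall(tp, fp, fn, mask_uses_fp):
    # Per-class recall tp / (tp + fn), treating 0 / 0 as 0
    recall = [t / (t + n) if t + n else 0.0 for t, n in zip(tp, fn)]
    # Zero the weight of "empty" classes; the two masks define "empty" differently
    if mask_uses_fp:
        weights = [0.0 if t + p + n == 0 else 1.0 for t, p, n in zip(tp, fp, fn)]
    else:
        weights = [0.0 if t + n == 0 else 1.0 for t, n in zip(tp, fn)]
    total = sum(weights)  # plays the role of weights.sum(-1, keepdim=True)
    return sum(w * r for w, r in zip(weights, recall)) / total if total else 0.0


# Same shape as the reproduction: 5 samples, 200 classes
random.seed(0)
num_classes = 200
preds = [[random.random() for _ in range(num_classes)] for _ in range(5)]
target = [random.randrange(num_classes) for _ in range(5)]

for k in (1, 5, 10, 100):
    tp, fp, fn = stat_scores(preds, target, num_classes, k)
    print(
        f"k={k}",
        f"mask tp+fp+fn==0: {macro_recall(tp, fp, fn, mask_uses_fp=True):.4f}",
        f"mask tp+fn==0: {macro_recall(tp, fp, fn, mask_uses_fp=False):.4f}",
    )
```

For a single sample with scores `[0.6, 0.3, 0.1]` and target `1`, the `tp + fp + fn == 0` mask gives 0.5 at k = 2 and 1/3 at k = 3 (class 2 becomes a false positive and joins the denominator), while the `tp + fn == 0` mask gives 1.0 for both.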