Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

Make `ignore_index` work when all batch elements are to be ignored #2685

Open fteufel opened 3 months ago

fteufel commented 3 months ago

🚀 Feature

The `ignore_index` argument in e.g. the AUROC metric allows specifying a label that will be excluded from the computation. This works well as long as only some batch elements are to be ignored. However, when the metric is called with a target tensor in which every entry equals the `ignore_index`, it raises an IndexError:

    self.aucs[f"val_label_{i}"](label_logits[:, i].squeeze(-1), labels_target[:, i])
  File "lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "lib/python3.10/site-packages/torchmetrics/metric.py", line 312, in forward
    self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
  File "lib/python3.10/site-packages/torchmetrics/metric.py", line 382, in _forward_reduce_state_update
    batch_val = self.compute()
  File "/lib/python3.10/site-packages/torchmetrics/metric.py", line 633, in wrapped_func
    value = _squeeze_if_scalar(compute(*args, **kwargs))
  File "lib/python3.10/site-packages/torchmetrics/classification/auroc.py", line 124, in compute
    return _binary_auroc_compute(state, self.thresholds, self.max_fpr)
  File "lib/python3.10/site-packages/torchmetrics/functional/classification/auroc.py", line 89, in _binary_auroc_compute
    fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label)
  File "lib/python3.10/site-packages/torchmetrics/functional/classification/roc.py", line 54, in _binary_roc_compute
    fps, tps, thres = _binary_clf_curve(preds=state[0], target=state[1], pos_label=pos_label)
  File "lib/python3.10/site-packages/torchmetrics/functional/classification/precision_recall_curve.py", line 72, in _binary_clf_curve
    tps = _cumsum(target * weight, dim=0)[threshold_idxs]
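
For reference, here is a minimal sketch that should reproduce the failure (the values are illustrative, and the exact frame where the IndexError surfaces may differ across versions):

import torch
from torchmetrics.classification import BinaryAUROC

metric = BinaryAUROC(ignore_index=-100)

preds = torch.tensor([0.2, 0.7, 0.9])
target = torch.tensor([-100, -100, -100])  # every element matches ignore_index

metric(preds, target)  # all elements are filtered out -> IndexError in compute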

Motivation

Having batches without any valid labels may sound counterintuitive at first, but in multitask problems it happens quite easily: when a metric tracks only one subtask and batches are sampled randomly, a batch may contain no examples for that subtask at all.

Pitch

It would be helpful if this just worked, perhaps by returning 0 or NaN (and maybe emitting a warning) instead of raising.
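
As a sketch of what such a guard could look like (the helper name safe_binary_auroc is hypothetical and not part of torchmetrics; it wraps the public functional binary_auroc):

import warnings

import torch
from torchmetrics.functional.classification import binary_auroc

def safe_binary_auroc(preds, target, ignore_index=-100):
    # hypothetical guard: drop ignored elements up front and return NaN
    # (with a warning) when nothing is left to score
    keep = target != ignore_index
    if not keep.any():
        warnings.warn("All targets equal ignore_index; returning NaN for AUROC.")
        return torch.tensor(float("nan"))
    return binary_auroc(preds[keep], target[keep])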

Alternatives

Right now, this has to be handled manually, like so:

# skip the update entirely when every target equals the ignore_index
if not (target == -100).all():
    self.auc(logits, target)

or, when calling `compute` after some update steps:

# only call compute() when at least one update saw un-ignored elements
if not all(len(x) == 0 for x in self.auc.metric_state["preds"]):
    self.auc.compute()
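
Both checks can be folded into a small subclass so that callers never hit the error. A sketch, assuming thresholds=None so that the metric state is a list of per-update tensors (the class name SafeBinaryAUROC is made up):

import torch
from torchmetrics.classification import BinaryAUROC

class SafeBinaryAUROC(BinaryAUROC):
    # hypothetical wrapper that consolidates the manual checks above
    def compute(self) -> torch.Tensor:
        # an all-empty state means every element seen so far matched ignore_index
        if all(len(p) == 0 for p in self.metric_state["preds"]):
            return torch.tensor(float("nan"))
        return super().compute()
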
github-actions[bot] commented 3 months ago

Hi! Thanks for your contribution, great first issue!