## 🚀 Feature

The `ignore_index` argument in e.g. the `AUROC` metric allows one to specify a label that will be ignored. This works great when some batch elements should be ignored. However, when the metric is called with a tensor in which every entry equals the `ignore_index`, an `IndexError` is raised:
```
self.aucs[f"val_label_{i}"](label_logits[:, i].squeeze(-1), labels_target[:, i])
  File "lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "lib/python3.10/site-packages/torchmetrics/metric.py", line 312, in forward
    self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
  File "lib/python3.10/site-packages/torchmetrics/metric.py", line 382, in _forward_reduce_state_update
    batch_val = self.compute()
  File "/lib/python3.10/site-packages/torchmetrics/metric.py", line 633, in wrapped_func
    value = _squeeze_if_scalar(compute(*args, **kwargs))
  File "lib/python3.10/site-packages/torchmetrics/classification/auroc.py", line 124, in compute
    return _binary_auroc_compute(state, self.thresholds, self.max_fpr)
  File "lib/python3.10/site-packages/torchmetrics/functional/classification/auroc.py", line 89, in _binary_auroc_compute
    fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label)
  File "lib/python3.10/site-packages/torchmetrics/functional/classification/roc.py", line 54, in _binary_roc_compute
    fps, tps, thres = _binary_clf_curve(preds=state[0], target=state[1], pos_label=pos_label)
  File "lib/python3.10/site-packages/torchmetrics/functional/classification/precision_recall_curve.py", line 72, in _binary_clf_curve
    tps = _cumsum(target * weight, dim=0)[threshold_idxs]
```
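For reference, a minimal sketch that reproduces this (the metric choice and the values are illustrative, not taken from the code above; only the all-`ignore_index` target matters):

```python
import torch
from torchmetrics.classification import BinaryAUROC

auroc = BinaryAUROC(ignore_index=-100)
preds = torch.tensor([0.2, 0.7, 0.5])
target = torch.tensor([-100, -100, -100])  # every entry is the ignore_index

auroc(preds, target)  # raises IndexError instead of returning a score
```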
## Motivation
Having batches without labels may sound counterintuitive at first, but in multitask problems it can happen quite easily: when a metric tracks only a given subtask and batches are drawn randomly, a batch may contain no valid labels for that subtask at all.
## Pitch
It would be helpful if this just worked (perhaps printing a warning), e.g. by returning `0` or `nan`.
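A hypothetical sketch of that behavior as a user-side subclass (assuming the default `thresholds=None`, where the state is accumulated in the `preds`/`target` lists; `SafeBinaryAUROC` is a made-up name):

```python
import warnings

import torch
from torchmetrics.classification import BinaryAUROC


class SafeBinaryAUROC(BinaryAUROC):
    """Hypothetical variant: warn and return nan instead of raising when
    every accumulated sample was filtered out via ignore_index."""

    def compute(self) -> torch.Tensor:
        # With thresholds=None, self.preds is a list of per-update tensors.
        if all(len(p) == 0 for p in self.preds):
            warnings.warn("No valid samples to compute AUROC over; returning nan.")
            return torch.tensor(float("nan"))
        return super().compute()
```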
## Alternatives
Right now, this needs to be handled manually, e.g.

```python
# skip the update when every entry is the ignore_index
if (target == -100).all():
    pass
else:
    self.auc(logits, target)
```
or, when calling `compute` after some `update` steps:

```python
# compute only if any state was accumulated; an empty state would raise
if all(len(x) == 0 for x in self.auc.metric_state["preds"]):
    pass
else:
    self.auc.compute()
```
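The two guards above could also be factored into small helpers (a sketch; `safe_forward` and `safe_compute` are made-up names, and `ignore_index=-100` plus the `metric_state` layout follow the snippets above):

```python
import torch
from torchmetrics import Metric


def safe_forward(metric: Metric, logits: torch.Tensor, target: torch.Tensor,
                 ignore_index: int = -100) -> None:
    # Update only when at least one target entry is a real label.
    if (target != ignore_index).any():
        metric(logits, target)


def safe_compute(metric: Metric) -> torch.Tensor:
    # Return nan instead of raising when nothing was accumulated.
    if all(len(x) == 0 for x in metric.metric_state["preds"]):
        return torch.tensor(float("nan"))
    return metric.compute()
```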