Lightning-AI / torchmetrics

Torchmetrics - Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

Option to return non-standardized AUC value when `max_fpr` is set for BinaryAUROC #2259

Open Adversarian opened 9 months ago

Adversarian commented 9 months ago

🚀 Feature

Add a simple keyword argument that controls whether McClish's standardization is applied to the AUC. Alternatively, return a pair of standardized and non-standardized AUCs when the argument is set.

Motivation

Currently, when max_fpr is set, BinaryAUROC only returns the AUC standardized according to McClish's method. However, sometimes the non-standardized version is what is needed. In my specific case, when I set max_fpr to 0.2, I was not expecting BinaryAUROC to return values above 0.2, since I expected 0.2 to be the theoretical maximum of the partial AUC (pAUC).

Pitch

Instead of offloading the de-standardization of pAUC scores to the user, this can be done within torchmetrics by adding a simple keyword argument. Furthermore, the existence of this keyword argument and its corresponding docstring entry make it clearer to the user what kind of output to expect from the method. Additionally, the docstring should explicitly explain what the standardized AUC means, either by providing an external link or, preferably, by giving the actual formulation of the standardization process.

Alternatives

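The de-standardization can be done manually by the user by inverting the normalization described by McClish, 1989. This has two immediate drawbacks:

  • it introduces unnecessary (albeit small) numerical error, since the raw partial AUC is already computed internally;
  • it requires the user to be familiar with the standardization process, for which there is surprisingly little material available.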

Additional context

I would be happy to submit a PR for this myself if the pitch is approved.

github-actions[bot] commented 9 months ago

Hi! Thanks for your contribution! Great first issue!

SkafteNicki commented 8 months ago

Hi @Adversarian, thanks for raising this issue (and sorry for not getting back to you before now). I am not completely sure what you are proposing in this issue. The max_fpr feature was proposed a long time ago and is implemented to behave exactly as in sklearn: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html. You are completely right that we should probably include a reference to the specific standardization being used.

Maybe the problem is that I am no expert in the standardization process, so could you provide more details on what you are proposing, in particular what changes to the user interface you have in mind?

Adversarian commented 8 months ago

Hi @SkafteNicki, sorry for the late response; I am just coming back from vacation.

I understand that most of the implementations in torchmetrics are meant to mirror their counterparts in existing libraries, and a cornerstone of my proposal is maintaining that consistency while offering further functionality to users who need it.

I will describe a prototype of what I'm proposing as a reference. Here is the current implementation of the function _binary_auroc_compute, taken verbatim from the torchmetrics source:

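from typing import Optional, Tuple, Union

import torch
from torch import Tensor, tensor

from torchmetrics.functional.classification.roc import _binary_roc_compute
from torchmetrics.utilities.compute import _auc_compute_without_check


def _binary_auroc_compute(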
    state: Union[Tensor, Tuple[Tensor, Tensor]],
    thresholds: Optional[Tensor],
    max_fpr: Optional[float] = None,
    pos_label: int = 1,
) -> Tensor:
    fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label)
    if max_fpr is None or max_fpr == 1 or fpr.sum() == 0 or tpr.sum() == 0:
        return _auc_compute_without_check(fpr, tpr, 1.0)

    _device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device
    max_area: Tensor = tensor(max_fpr, device=_device)
    # Add a single point at max_fpr and interpolate its tpr value
    stop = torch.bucketize(max_area, fpr, out_int32=True, right=True)
    weight = (max_area - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
    interp_tpr: Tensor = torch.lerp(tpr[stop - 1], tpr[stop], weight)
    tpr = torch.cat([tpr[:stop], interp_tpr.view(1)])
    fpr = torch.cat([fpr[:stop], max_area.view(1)])

    # Compute partial AUC
    partial_auc = _auc_compute_without_check(fpr, tpr, 1.0)

    # McClish correction: standardize result to be 0.5 if non-discriminant and 1 if maximal
    min_area: Tensor = 0.5 * max_area**2
    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))

As you can see, near the end of the function there is a step described as the "McClish correction" (also referred to as McClish standardization). It ensures that the AUROC returned by the function has a maximum of 1 even when max_fpr is set. When max_fpr is set, the raw partial AUC lies in [0, max_fpr]; McClish's correction rescales it linearly so that a non-discriminant classifier (partial AUC of 0.5 * max_fpr**2) scores 0.5 and a perfect one (partial AUC of max_fpr) scores 1, which keeps its behavior consistent with the case where max_fpr isn't set. However, some people (yours truly included!) may prefer to work with the raw partial AUC values instead of the standardized/corrected versions.
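Written out with p = max_fpr (notation introduced here only for brevity), the last two lines of the snippet are the following linear map. Substituting the two reference areas gives the endpoints named in the code comment, and a partial AUC of 0 (worse than chance) maps below 0.5:

```latex
\mathrm{sAUC} = \frac{1}{2}\left(1 + \frac{\mathrm{pAUC} - A_{\min}}{A_{\max} - A_{\min}}\right),
\qquad A_{\min} = \tfrac{1}{2}p^{2}, \quad A_{\max} = p

\mathrm{pAUC} = A_{\min} \;\Rightarrow\; \mathrm{sAUC} = \tfrac{1}{2}, \qquad
\mathrm{pAUC} = A_{\max} \;\Rightarrow\; \mathrm{sAUC} = 1, \qquad
\mathrm{pAUC} = 0 \;\Rightarrow\; \mathrm{sAUC} = \frac{1-p}{2-p}
```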

As is evident from the code snippet above, reversing the correction is trivial because max_area and min_area depend only on max_fpr. However, I do not see why this reversal should be offloaded to the user when the uncorrected partial AUC (when max_fpr is set) is already computed and available in the partial_auc variable. The round trip introduces unnecessary (albeit small) floating-point error and requires the user to be familiar with the standardization process, for which there is surprisingly little material on the internet.
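For reference, here is a minimal sketch of that manual reversal in plain Python. It assumes the standardized value came from the max_fpr branch of the snippet; the early-return branch returns the ordinary full AUC, which must not be inverted. The function names mcclish_standardize and mcclish_destandardize are illustrative only, not torchmetrics APIs.

```python
def mcclish_standardize(partial_auc: float, max_fpr: float) -> float:
    """McClish standardization, same formula as the end of _binary_auroc_compute."""
    min_area = 0.5 * max_fpr**2  # pAUC of a non-discriminant (chance) classifier
    max_area = max_fpr           # pAUC of a perfect classifier
    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))


def mcclish_destandardize(standardized_auc: float, max_fpr: float) -> float:
    """Invert the McClish standardization to recover the raw partial AUC.

    Illustrative helper (not part of torchmetrics); only valid for values that
    were actually standardized, i.e. computed with 0 < max_fpr < 1.
    """
    min_area = 0.5 * max_fpr**2
    max_area = max_fpr
    # Solve s = 0.5 * (1 + (p - min_area) / (max_area - min_area)) for p.
    return min_area + (2 * standardized_auc - 1) * (max_area - min_area)


if __name__ == "__main__":
    max_fpr = 0.2
    raw = 0.15
    std = mcclish_standardize(raw, max_fpr)
    print(round(std, 4))  # → 0.8611
    print(round(mcclish_destandardize(std, max_fpr), 4))  # → 0.15
```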

My proposal then is twofold:

  1. Without breaking backward compatibility or scikit-learn API compatibility, we add a new boolean keyword argument (return_standardized, for instance) that defaults to True. When return_standardized is set to False, the function does one of the following:
    • instead of returning only standardized_partial_auc, return the tuple (partial_auc, standardized_partial_auc); or
    • return only partial_auc instead of standardized_partial_auc.

where partial_auc is defined in the code snippet above and

standardized_partial_auc = 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))
  2. The documentation of the function would include a further explanation of, and possibly a link to a reference on, what the standardization/correction process described by McClish entails. (This targets binary_auroc in the functional API and BinaryAUROC in the modular API.)
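To make the first point concrete, here is a hypothetical standalone sketch in plain Python (no torch, and not torchmetrics code) of how a return_standardized flag could behave. The names binary_partial_auroc and _roc_points are made up for illustration. The truncation and interpolation at max_fpr and the McClish formula mirror _binary_auroc_compute above, and the sketch implements the second option (returning only partial_auc when the flag is False). It assumes both classes are present and 0 < max_fpr < 1, and it skips the early-return cases of the real function.

```python
from typing import List, Sequence, Tuple


def _roc_points(preds: Sequence[float], target: Sequence[int]) -> Tuple[List[float], List[float]]:
    """Return ROC (fpr, tpr) points, lowering the threshold one distinct score at a time."""
    pairs = sorted(zip(preds, target), key=lambda p: p[0], reverse=True)
    n_pos = sum(target)
    n_neg = len(target) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("both classes must be present")
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        score = pairs[i][0]
        # Tied scores cross the threshold together, producing a single ROC point.
        while i < len(pairs) and pairs[i][0] == score:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        fpr.append(fp / n_neg)
        tpr.append(tp / n_pos)
    return fpr, tpr


def binary_partial_auroc(
    preds: Sequence[float],
    target: Sequence[int],
    max_fpr: float,
    return_standardized: bool = True,  # hypothetical flag from the proposal
) -> float:
    """Partial AUROC up to max_fpr, McClish-standardized unless return_standardized=False."""
    if not 0 < max_fpr < 1:
        raise ValueError("this sketch only handles 0 < max_fpr < 1")
    fpr, tpr = _roc_points(preds, target)
    # First ROC point with fpr > max_fpr (the index torch.bucketize(..., right=True) returns).
    stop = next(k for k, f in enumerate(fpr) if f > max_fpr)
    # Linearly interpolate tpr at exactly max_fpr and truncate the curve there.
    weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
    interp_tpr = tpr[stop - 1] + weight * (tpr[stop] - tpr[stop - 1])
    xs = fpr[:stop] + [max_fpr]
    ys = tpr[:stop] + [interp_tpr]
    # Trapezoidal rule over the truncated curve gives the raw (non-standardized) pAUC.
    partial_auc = sum((x1 - x0) * (y0 + y1) / 2 for x0, x1, y0, y1 in zip(xs, xs[1:], ys, ys[1:]))
    if not return_standardized:
        return partial_auc
    # McClish correction, same formula as in _binary_auroc_compute.
    min_area = 0.5 * max_fpr**2
    max_area = max_fpr
    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))


if __name__ == "__main__":
    preds = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4]
    target = [1, 1, 0, 1, 0, 0]
    print(round(binary_partial_auroc(preds, target, 0.2), 4))  # → 0.8148
    print(round(binary_partial_auroc(preds, target, 0.2, return_standardized=False), 4))  # → 0.1333
```

On this toy data the standardized value is well above max_fpr = 0.2 while the raw pAUC stays below it, which is exactly the discrepancy described in the Motivation section.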

If there's anything I can further elaborate on, please let me know.