Lightning-AI / torchmetrics

Torchmetrics - Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

segmentation.MeanIoU is wrong #2558

Open nrudakov opened 1 month ago

nrudakov commented 1 month ago

🐛 Bug

MeanIoU reported a score of 56(!) on a validation dataset image, even though IoU can never exceed 1.

Let's look at the source code: https://github.com/Lightning-AI/torchmetrics/blob/596ed09b18c3d6786cb094c87da97838228461f3/src/torchmetrics/segmentation/mean_iou.py#L109

def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the state with the new data."""
    intersection, union = _mean_iou_update(preds, target, self.num_classes, self.include_background)
    score = _mean_iou_compute(intersection, union, per_class=self.per_class)
    self.score += score.mean(0) if self.per_class else score.mean()

def compute(self) -> Tensor:
    """Update the state with the new data."""
    return self.score  # / self.num_batches

There are several issues there:

Obviously, that code was neither reviewed nor tested, but somehow was released.
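To illustrate one of the problems visible in the quoted snippet: update adds a per-batch score to self.score, and compute returns the running sum because the division by self.num_batches is commented out. The following is a minimal sketch in plain PyTorch, not the library code; batch_iou and the data are hypothetical names and values chosen only to show how such a sum can leave the [0, 1] range.

```python
import torch


def batch_iou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """IoU of two boolean masks (hypothetical helper, not torchmetrics API)."""
    inter = (preds & target).sum()
    union = (preds | target).sum()
    return inter / union


# Hypothetical data: three batches with identical, perfect predictions.
mask = torch.tensor([True, False, True, True])
batches = [(mask, mask)] * 3

score = torch.tensor(0.0)
for preds, target in batches:
    score += batch_iou(preds, target)  # mirrors `self.score += ...`

num_batches = len(batches)
unnormalized = score.clone()      # what the quoted compute() returns
normalized = score / num_batches  # what the commented-out division would give

print(unnormalized.item(), normalized.item())
```

With a perfect IoU of 1 per batch, the unnormalized sum grows with the number of update calls, which is consistent with seeing a value like 56 after a full validation epoch.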

To Reproduce

Call metric.update(y_hat, y) in validation_step.
Log the metric in on_validation_epoch_end.

Expected behavior

MeanIoU computes a correct value in the [0, 1] range.

Environment

github-actions[bot] commented 1 month ago

Hi! thanks for your contribution!, great first issue!

DimitrisMantas commented 1 month ago

I think now with zero_division having been added to JaccardIndex, there's no real need for MeanIoU at all. MeanIoU also doesn't follow the classical API.

The original motivation for this was that JaccardIndex used to assign a score of 0 to absent and ignored classes so you couldn't do classwise and macro averaging correctly. Now, you can just set zero_division to NaN and average to None, and get correct class scores. From there, you could do a nanmean to get the correct macro average.
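The idea of marking absent classes as NaN and then taking a nanmean can be sketched in plain PyTorch. This sketch deliberately does not call JaccardIndex, so it makes no claim about the exact torchmetrics arguments; per_class_iou is a hypothetical helper and the tensors are made-up data.

```python
import torch


def per_class_iou(preds: torch.Tensor, target: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Per-class IoU with NaN for classes absent from both preds and target."""
    ious = []
    for c in range(num_classes):
        p, t = preds == c, target == c
        inter = (p & t).sum()
        union = (p | t).sum()
        if union > 0:
            ious.append(inter / union)
        else:
            # Absent class: the score is undefined, so mark it instead of using 0.
            ious.append(torch.tensor(float("nan")))
    return torch.stack(ious)


# Hypothetical labels: class 2 never appears in preds or target.
preds = torch.tensor([0, 0, 1, 1])
target = torch.tensor([0, 1, 1, 1])

iou = per_class_iou(preds, target, num_classes=3)
macro = torch.nanmean(iou)                            # averages classes 0 and 1 only
zero_filled = torch.nan_to_num(iou, nan=0.0).mean()   # old behaviour: absent class counts as 0

print(iou, macro, zero_filled)
```

Treating the absent class as 0 drags the macro average down, while the NaN-aware mean only averages over classes that actually occur.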

juliendenize commented 1 month ago

Hi, I also noticed that the MeanIoU was wrong during my experiments.

I developed something that seems to work on my side and returns the same results as evaluate's mean IoU. However, based on @DimitrisMantas's comment, I wonder whether submitting a PR is still relevant. I am not familiar enough with the JaccardIndex implementation in torchmetrics.

For reference, here is the undocumented code I developed (it has not been rigorously tested yet). Let me know if a PR would be of interest; I'd gladly contribute to this repo.

from typing import Any, Literal

import torch
from torch import Tensor
from torchmetrics import Metric

def _compute_intersection_and_union(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    include_background: bool = False,
    input_format: Literal["one-hot", "index", "predictions"] = "index",
) -> tuple[Tensor, Tensor]:
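    # "predictions" expects class scores along dim 1; "index" expects integer class labels.
    # "one-hot" inputs are assumed to already have the class axis last.
    if input_format in ["index", "predictions"]:
        if input_format == "predictions":
            preds = preds.argmax(1)
        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
        target = torch.nn.functional.one_hot(target, num_classes=num_classes)

    if not include_background:
        # Zero the background channel (class 0) on copies so that
        # caller-provided one-hot tensors are not modified in place.
        preds, target = preds.clone(), target.clone()
        preds[..., 0] = 0
        target[..., 0] = 0

    # Sum over the spatial dimensions, keeping batch (dim 0) and class (last dim).
    reduce_axis = list(range(1, preds.ndim - 1))
    intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis)
    target_sum = torch.sum(target, dim=reduce_axis)
    pred_sum = torch.sum(preds, dim=reduce_axis)
    union = target_sum + pred_sum - intersection

    return intersection, union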

class MeanIoU(Metric):
    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        per_class: bool = False,
        input_format: Literal["one-hot", "index", "predictions"] = "index",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.include_background = include_background
        self.per_class = per_class
        self.input_format = input_format

        self.add_state("intersection", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("union", default=torch.zeros(num_classes), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        intersection, union = _compute_intersection_and_union(
            preds, target, self.num_classes, self.include_background, self.input_format
        )
        self.intersection += intersection.sum(0)
        self.union += union.sum(0)

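    def compute(self) -> Tensor:
        # Classes with an empty union (absent, or the excluded background) have no defined IoU.
        iou_valid = torch.gt(self.union, 0)

        iou = torch.where(
            iou_valid,
            torch.divide(self.intersection, self.union),
            torch.nan,
        )

        if self.per_class:
            return iou
        # Macro average over the classes that actually occurred.
        return torch.mean(iou[iou_valid])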
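A note on the design choice in the snippet above: it accumulates intersection and union counts as states and divides once in compute, rather than averaging per-batch scores. The following plain-PyTorch sketch, using hypothetical pixel counts for a single class, shows how the two strategies can differ; it is an illustration only and does not depend on torchmetrics.

```python
import torch

# Hypothetical per-batch (intersection, union) pixel counts for a single class.
batches = [
    (torch.tensor(1.0), torch.tensor(2.0)),     # small object, IoU 0.5
    (torch.tensor(90.0), torch.tensor(100.0)),  # large object, IoU 0.9
]

# Strategy A: average the per-batch IoU scores.
mean_of_ious = torch.stack([i / u for i, u in batches]).mean()

# Strategy B: sum the counts first, then divide once
# (the approach taken by the intersection/union states above).
total_inter = sum(i for i, _ in batches)
total_union = sum(u for _, u in batches)
iou_of_sums = total_inter / total_union

print(mean_of_ious.item(), iou_of_sums.item())
```

Strategy B stays in [0, 1] regardless of how many batches are seen and does not depend on how the data is split into batches, which also makes a "sum" reduction across processes a natural fit.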