nrudakov opened this issue 1 month ago
Hi! Thanks for your contribution! Great first issue!
I think that now that zero_division has been added to JaccardIndex, there's no real need for MeanIoU at all. MeanIoU also doesn't follow the classical API.
The original motivation for MeanIoU was that JaccardIndex used to assign a score of 0 to absent and ignored classes, so you couldn't do classwise and macro averaging correctly. Now you can just set zero_division to NaN and average to None to get correct class scores, and from there a nanmean gives the correct macro average.
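For illustration, a minimal sketch of that recipe (assuming a torchmetrics version where the new `zero_division` argument is available on JaccardIndex and accepts NaN):

```python
import torch
from torchmetrics import JaccardIndex

# Sketch only: `zero_division=float("nan")` is assumed to be supported, per
# the discussion above.
jaccard = JaccardIndex(
    task="multiclass", num_classes=4, average=None, zero_division=float("nan")
)
preds = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])  # class 3 is absent -> its score is NaN
per_class = jaccard(preds, target)   # classwise scores, NaN where undefined
macro = per_class.nanmean()          # macro average over defined classes only
```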
Hi, I also noticed that MeanIoU was wrong during my experiments.
I developed something that seems to work on my side and returns the same results as evaluate's mean IoU. However, based on @DimitrisMantas's comment, I wonder whether submitting a PR is still relevant; I am not familiar enough with the JaccardIndex implementation in torchmetrics.
For reference, here is the code I developed, with brief comments (it has not been rigorously tested yet). Let me know if a PR would be of interest to you; I'd gladly contribute to this repo.
```python
from typing import Any, Literal

import torch
from torch import Tensor
from torchmetrics import Metric


def _compute_intersection_and_union(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    include_background: bool = False,
    input_format: Literal["one-hot", "index", "predictions"] = "index",
) -> tuple[Tensor, Tensor]:
    # Convert everything to one-hot, channel-last layout (N, ..., C).
    if input_format in ["index", "predictions"]:
        if input_format == "predictions":
            # Logits/probabilities of shape (N, C, ...): reduce to index maps.
            preds = preds.argmax(1)
        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
        target = torch.nn.functional.one_hot(target, num_classes=num_classes)
    if not include_background:
        # Zero out the background channel so it contributes to neither the
        # intersection nor the union.
        preds[..., 0] = 0
        target[..., 0] = 0
    # Reduce over the spatial dimensions, keeping batch and class dimensions.
    reduce_axis = list(range(1, preds.ndim - 1))
    intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis)
    target_sum = torch.sum(target, dim=reduce_axis)
    pred_sum = torch.sum(preds, dim=reduce_axis)
    union = target_sum + pred_sum - intersection
    return intersection, union


class MeanIoU(Metric):
    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        per_class: bool = False,
        input_format: Literal["one-hot", "index", "predictions"] = "index",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.include_background = include_background
        self.per_class = per_class
        self.input_format = input_format
        # Accumulate raw per-class counts instead of a running score, so that
        # `compute` yields an exact dataset-level IoU.
        self.add_state("intersection", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("union", default=torch.zeros(num_classes), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        intersection, union = _compute_intersection_and_union(
            preds, target, self.num_classes, self.include_background, self.input_format
        )
        # Sum over the batch dimension into the per-class running totals.
        self.intersection += intersection.sum(0)
        self.union += union.sum(0)

    def compute(self) -> Tensor:
        # Classes that never appear in preds or target have an empty union;
        # mark them NaN so they can be excluded from the average.
        iou_valid = torch.gt(self.union, 0)
        iou = torch.where(
            iou_valid,
            torch.divide(self.intersection, self.union),
            torch.nan,
        )
        if self.per_class:
            return iou
        return torch.mean(iou[iou_valid])
```
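Hypothetical usage of the sketch above, assuming `(N, H, W)` integer index maps:

```python
metric = MeanIoU(num_classes=3, include_background=True, per_class=False)
preds = torch.randint(0, 3, (2, 16, 16))   # predicted index maps
target = torch.randint(0, 3, (2, 16, 16))  # ground-truth index maps
metric.update(preds, target)
print(metric.compute())  # scalar mean IoU in [0, 1]
```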
🐛 Bug
MeanIoU scored 56(!) over the validation dataset:

![image](https://github.com/Lightning-AI/torchmetrics/assets/5420651/a98e66c8-48fa-4e8d-a02a-ee09d1783be6)
Let's look at the source code: https://github.com/Lightning-AI/torchmetrics/blob/596ed09b18c3d6786cb094c87da97838228461f3/src/torchmetrics/segmentation/mean_iou.py#L109
There are several issues there:

- `self.score` is accumulated with each call of the `update` method
- the `compute` method just returns the accumulated score
- the `compute` method is copy-pasted from the `update` method
- `self.num_batches` is commented out
- `num_batches` is defined at class level and not used anywhere else

Obviously, that code was neither reviewed nor tested, but somehow it was released.
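To make the first two points concrete, here is a toy illustration with made-up numbers of how a running sum returned by `compute` escapes the [0, 1] range:

```python
# Toy illustration (made-up numbers): update() adds each batch score to
# self.score, and compute() returns the raw sum rather than an average.
score = 0.0
for _ in range(70):  # e.g. 70 validation batches
    score += 0.8     # per-batch IoU accumulated by update()
print(score)         # 56.0 -- the same magnitude as the value reported above
```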
To Reproduce

1. Call `metric.update(y_hat, y)` in `validation_step`.
2. Log the metric in `on_validation_epoch_end`.
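A minimal sketch of those two steps inside a LightningModule (hypothetical module; the 1x1-conv stand-in model and `input_format="index"` are assumptions):

```python
import torch
from pytorch_lightning import LightningModule
from torchmetrics.segmentation import MeanIoU


class SegModel(LightningModule):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.model = torch.nn.Conv2d(3, num_classes, kernel_size=1)  # stand-in
        self.miou = MeanIoU(num_classes=num_classes, input_format="index")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x).argmax(1)
        self.miou.update(y_hat, y)

    def on_validation_epoch_end(self):
        # With the buggy release, the logged value grows with the number of
        # validation batches instead of staying in [0, 1].
        self.log("val_miou", self.miou.compute())
        self.miou.reset()
```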
Expected behavior
MeanIoU computes a correct value in the [0, 1] range.
Environment