Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

Segmentation IoU: ignore tagged values that should not be counted (such as 255) #2747

Open woldier opened 2 months ago

woldier commented 2 months ago

🚀 Feature

When we compute IoU on a target that contains a label meant to be ignored, the metric fails:

import torch

_ = torch.manual_seed(0)
from torchmetrics.segmentation import MeanIoU

miou = MeanIoU(num_classes=3)
preds = torch.randint(0, 2, (5,))
target = torch.as_tensor((0, 1, 2, 0, 255))  # An index of 255 is a tag to be ignored.
miou(preds, target)
# This raises an error.
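
The failure can be reproduced without torchmetrics. The sketch below assumes the index-format path one-hot encodes the target, as in the code linked under Alternatives; `torch.nn.functional.one_hot` rejects any label that is not smaller than `num_classes`. The flag `one_hot_failed` is only an illustrative name.

```python
import torch

target = torch.as_tensor((0, 1, 2, 0, 255))  # 255 marks a pixel to ignore

try:
    torch.nn.functional.one_hot(target, num_classes=3)
    one_hot_failed = False
except RuntimeError as err:
    # 255 is not a valid class index when num_classes=3
    one_hot_failed = True
    print(f"one_hot failed: {err}")
```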

Motivation

When I generate sample pairs (an image and its corresponding mask, assuming 3 classes), not every pixel in the mask belongs to one of the classes, so I set those pixels to 255. They are then ignored in the loss calculation via torch.nn.CrossEntropyLoss(ignore_index=255). The IoU calculation has no equivalent option, which leads to errors, so I would like it to support an ignore_index parameter as well, so that certain pixels can be ignored.
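
For reference, here is a small sketch of the loss-side behaviour described above. The logits are random and purely illustrative; the point is that `ignore_index=255` gives the same loss as dropping the tagged pixels by hand before calling the loss.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(5, 3)  # 5 pixels, 3 classes (made-up values)
target = torch.as_tensor((0, 1, 2, 0, 255))  # last pixel is tagged as ignored

# Loss with the built-in ignore mechanism
loss_ignored = torch.nn.CrossEntropyLoss(ignore_index=255)(logits, target)

# Same loss computed only on the pixels that are not tagged
keep = target != 255
loss_kept = torch.nn.CrossEntropyLoss()(logits[keep], target[keep])

print(loss_ignored, loss_kept)
```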

Pitch

import torch
_ = torch.manual_seed(0)
from torchmetrics.segmentation import MeanIoU
miou = MeanIoU(num_classes=3, ignore_index=255)  # support ignore_index param to ignore index 255
preds = torch.randint(0, 2, (5,))
target = torch.as_tensor((0, 1, 2, 0, 255))  # An index of 255 is a tag to be ignored.
miou(preds, target)

Alternatives

https://github.com/Lightning-AI/torchmetrics/blob/62d9d32280ba365f1c2c14e0bd8a5adc959a1a6e/src/torchmetrics/functional/segmentation/mean_iou.py#L42

https://github.com/Lightning-AI/torchmetrics/blob/62d9d32280ba365f1c2c14e0bd8a5adc959a1a6e/src/torchmetrics/functional/segmentation/mean_iou.py#L52-L55

def _mean_iou_update(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    include_background: bool = False,
    input_format: Literal["one-hot", "index"] = "one-hot",
    ignore_index: int = 255,  # proposed new parameter
) -> Tuple[Tensor, Tensor]:
    ...

    if input_format == "index":
        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
        ignored = target == ignore_index  # boolean mask of labels to ignore
        target = target.masked_fill(ignored, 0)  # temporary valid class so one_hot does not fail; leaves the caller's tensor untouched
        target = torch.nn.functional.one_hot(target, num_classes=num_classes)
        target[ignored] = 0  # turn the one-hot rows of ignored labels into all-zero rows
        target = target.movedim(-1, 1)
    ...
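
Until such a parameter exists, one possible workaround is to drop the tagged pixels before computing IoU. The following is a minimal sketch in plain PyTorch, not the torchmetrics implementation: `mean_iou_ignoring` is a hypothetical helper, it works on flat index tensors only, and it averages over the classes that appear in either tensor, which may differ from how MeanIoU aggregates.

```python
import torch


def mean_iou_ignoring(preds, target, num_classes, ignore_index=255):
    """Hypothetical helper: mean IoU over classes, skipping ignored pixels."""
    keep = target != ignore_index
    preds, target = preds[keep], target[keep]  # drop ignored pixels from both
    ious = []
    for c in range(num_classes):
        pred_c, target_c = preds == c, target == c
        union = (pred_c | target_c).sum()
        if union == 0:
            continue  # class absent from both tensors, so it is skipped
        ious.append((pred_c & target_c).sum() / union)
    if not ious:
        return torch.tensor(float("nan"))  # nothing left to score
    return torch.stack(ious).mean()


preds = torch.as_tensor((0, 1, 1, 0, 1))
target = torch.as_tensor((0, 1, 2, 0, 255))
result = mean_iou_ignoring(preds, target, num_classes=3)
print(result)  # per-class IoU is 1.0, 0.5, 0.0, so the mean is 0.5
```

Unlike the patch sketched above, this variant also discards the predictions at ignored positions. In the patch, preds is still fully one-hot encoded there, so depending on how the elided part computes the union, those predictions may still be counted.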

Additional context

github-actions[bot] commented 2 months ago

Hi! Thanks for your contribution! Great first issue!