pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Segmentation masks IoU support #5726

Open Ilyabasharov opened 2 years ago

Ilyabasharov commented 2 years ago

🚀 The feature

Pairwise IoU on segmentation masks, similar to what torchvision.ops.box_iou does for boxes.

Motivation, pitch

motivation: cost matrix construction in object tracking

Alternatives

pycocotools.mask.iou (the module is commonly imported as maskUtils)

Additional context

No response

cc @datumbox @YosuaMichael

datumbox commented 2 years ago

@Ilyabasharov thanks for the recommendation.

I think it's an operator we could consider adding in TorchVision. If you send a PR that adds it, I'll be happy to review it.

oke-aditya commented 2 years ago

I gave this some thought. Masks can be represented in very different ways. pycocotools, for example, supports RLE masks (see #4415).

Also, I'm not sure what IoU in segmentation would actually refer to. Suppose the masks are represented as bit masks (or boolean masks), say:

Ground truth
[
 [0 0 0 0 1]
 [1 0 1 0 0]
]
Pred mask
[
 [0 1 0 0 1]
 [0 1 1 0 0]
]

Then I think IoU is simply the Jaccard index between these two? This isn't the case if the masks are represented as RLE masks. Maybe refer to detectron2 as well: https://github.com/facebookresearch/detectron2/blob/6886f85baee349556749680ae8c85cdba1782d8e/detectron2/structures/masks.py#L173

Currently, our mask utils support boolean or bit masks.
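To make the boolean-mask case concrete, here is a minimal sketch that computes the Jaccard index for the two example masks above. It assumes plain boolean tensors and is only an illustration, not an existing torchvision operator; the variable names are made up for this example.

```python
import torch

# The two example masks from the comment above, as boolean tensors.
gt = torch.tensor([[0, 0, 0, 0, 1],
                   [1, 0, 1, 0, 0]], dtype=torch.bool)
pred = torch.tensor([[0, 1, 0, 0, 1],
                     [0, 1, 1, 0, 0]], dtype=torch.bool)

intersection = (gt & pred).sum().item()  # pixels set in both masks
union = (gt | pred).sum().item()         # pixels set in at least one mask

# Guard against two empty masks (0 / 0); returning 0.0 there is a convention, not a rule.
iou = intersection / union if union > 0 else 0.0
print(intersection, union, iou)  # 2 5 0.4
```

The masks share two pixels and cover five pixels in total, so the IoU is 2 / 5 = 0.4. RLE-encoded masks would need to be decoded (or handled by a dedicated routine) before this kind of element-wise computation applies.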

vadimkantorov commented 2 years ago

Related discussion: https://github.com/pytorch/vision/issues/4415

Ilyabasharov commented 2 years ago

@datumbox @oke-aditya Here is my answer:
def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, ) -> torch.Tensor:

VasLem commented 1 year ago

Extending @Ilyabasharov's function, here is a naive version that accepts batches:

import torch


@torch.jit.script
def mask_iou(
    mask1: torch.Tensor,
    mask2: torch.Tensor,
) -> torch.Tensor:

    """
    Inputs:
    mask1: BxNxHxW torch.float32. Binary mask, values in {0, 1}
    mask2: BxMxHxW torch.float32. Binary mask, values in {0, 1}
    Outputs:
    ret: BxNxM torch.float32. Pairwise IoU, values in [0, 1]
    """

    B, N, H, W = mask1.shape
    B, M, H, W = mask2.shape

    mask1 = mask1.view(B, N, H * W)
    mask2 = mask2.view(B, M, H * W)

    intersection = torch.matmul(mask1, mask2.swapaxes(1, 2))

    area1 = mask1.sum(dim=2).unsqueeze(1)
    area2 = mask2.sum(dim=2).unsqueeze(1)

    union = (area1.swapaxes(1, 2) + area2) - intersection

    ret = torch.where(
        union == 0,
        torch.tensor(0.0, device=mask1.device),
        intersection / union,
    )

    return ret

vadimkantorov commented 1 year ago

@VasLem It might be even nicer to use mask1.flatten(start_dim=-2) instead of the view, though some extra asserts on shape correctness would be nice as well.
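A rough sketch of what that suggestion could look like, assuming binary masks of shape (B, K, H, W). The name mask_iou_checked is hypothetical and used only for illustration; the function runs in eager mode and is not scripted with torch.jit.script.

```python
import torch


def mask_iou_checked(mask1: torch.Tensor, mask2: torch.Tensor) -> torch.Tensor:
    """Pairwise IoU between two batches of binary masks (illustrative sketch).

    mask1: (B, N, H, W), mask2: (B, M, H, W), values in {0, 1}.
    Returns a (B, N, M) float tensor with values in [0, 1].
    """
    # Shape checks, as suggested in the comment above.
    assert mask1.dim() == 4 and mask2.dim() == 4, "expected 4-D (B, K, H, W) inputs"
    assert mask1.shape[0] == mask2.shape[0], "batch sizes differ"
    assert mask1.shape[-2:] == mask2.shape[-2:], "spatial sizes differ"

    # flatten also handles non-contiguous inputs, where view would raise.
    m1 = mask1.flatten(start_dim=-2).float()  # (B, N, H*W)
    m2 = mask2.flatten(start_dim=-2).float()  # (B, M, H*W)

    # For binary masks, the dot product counts pixels set in both masks.
    intersection = m1 @ m2.transpose(1, 2)    # (B, N, M)
    area1 = m1.sum(dim=2).unsqueeze(2)        # (B, N, 1)
    area2 = m2.sum(dim=2).unsqueeze(1)        # (B, 1, M)
    union = area1 + area2 - intersection      # (B, N, M) via broadcasting

    # Two empty masks give 0 / 0; report 0 for those pairs instead of NaN.
    return torch.where(union > 0, intersection / union, torch.zeros_like(union))


# Usage: one batch with three 2x5 masks compared against themselves.
a = torch.tensor([[0, 0, 0, 0, 1], [1, 0, 1, 0, 0]])
b = torch.tensor([[0, 1, 0, 0, 1], [0, 1, 1, 0, 0]])
masks = torch.stack([a, b, torch.zeros_like(a)]).unsqueeze(0)  # (1, 3, 2, 5)
print(mask_iou_checked(masks, masks).shape)  # torch.Size([1, 3, 3])
```

Using flatten avoids the contiguity requirement of view, and the asserts fail early on mismatched batch or spatial sizes instead of producing a confusing matmul error.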