qubvel-org / segmentation_models.pytorch

Semantic segmentation models with 500+ pretrained convolutional and transformer-based backbones.
https://smp.readthedocs.io/
MIT License

Modify Jaccard, Dice and Tversky losses #927

Closed zifuwanggg closed 1 month ago

zifuwanggg commented 2 months ago

The Jaccard, Dice and Tversky losses in losses._functional are modified based on JDTLoss. The existing implementations compute the intersection as an inner product, which is only valid for hard (one-hot) labels; the modified versions also handle soft labels correctly, as the example below shows.

Example

import torch
import torch.linalg as LA
import torch.nn.functional as F

torch.manual_seed(0)

b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)

pred = torch.rand(b, c, h, w).softmax(dim=1)
soft_label = torch.rand(b, c, h, w).softmax(dim=1)
hard_label = torch.randint(low=0, high=c, size=(b, h, w))
one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2)

def dice_old(x, y, dims):
    # Intersection as an inner product: matches the set intersection
    # only when y is one-hot.
    cardinality = torch.sum(x, dim=dims) + torch.sum(y, dim=dims)
    intersection = torch.sum(x * y, dim=dims)
    return 2 * intersection / cardinality

def dice_new(x, y, dims):
    # Intersection via the L1 distance, valid for soft labels as well:
    # (|x| + |y| - ||x - y||_1) / 2 = sum(min(x, y)).
    cardinality = torch.sum(x, dim=dims) + torch.sum(y, dim=dims)
    difference = LA.vector_norm(x - y, ord=1, dim=dims)
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality

print(dice_old(pred, one_hot_label, dims), dice_new(pred, one_hot_label, dims))
print(dice_old(pred, soft_label, dims), dice_new(pred, soft_label, dims))
print(dice_old(pred, pred, dims), dice_new(pred, pred, dims))

# hard label: old and new agree
# tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317])
# soft label: only the new version reflects the overlap
# tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700])
# identical inputs: the new version correctly returns 1
# tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.])
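For intuition on why the two formulations agree on hard labels but diverge on soft ones: elementwise, |x - y| = x + y - 2 min(x, y), so the reformulated intersection equals sum(min(x, y)); when y is one-hot and x is in [0, 1], min(x, y) = x * y, recovering the old inner-product intersection. A minimal sketch checking this identity (assumes only PyTorch, names are illustrative):

```python
import torch

torch.manual_seed(0)

b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)

x = torch.rand(b, c, h, w).softmax(dim=1)
y = torch.rand(b, c, h, w).softmax(dim=1)

# For x, y >= 0 elementwise, |x - y| = x + y - 2 * min(x, y), so
# (cardinality - difference) / 2 recovers the soft intersection sum(min(x, y)).
cardinality = x.sum(dim=dims) + y.sum(dim=dims)
difference = (x - y).abs().sum(dim=dims)
intersection = (cardinality - difference) / 2

assert torch.allclose(intersection, torch.minimum(x, y).sum(dim=dims))
```

With a one-hot y, min(x, y) collapses to x * y, which is why the first printed row above is identical for both versions.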

References

[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. MICCAI 2023.

[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. NeurIPS 2023.