HuCaoFighting / Swin-Unet

[ECCVW 2022] The code for the work "Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation"

Dice loss is weird #106

Closed: hdnminh closed this issue 6 months ago

hdnminh commented 6 months ago

Hi @HuCaoFighting

I hope you have a great last week of the year!

Could you explain your Dice loss implementation?

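    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5  # avoids 0/0 when both score and target are all zeros
        intersect = torch.sum(score * target)  # soft intersection: sum(p * g)
        y_sum = torch.sum(target * target)     # sum(g^2); equals sum(g) for a 0/1 mask
        z_sum = torch.sum(score * score)       # sum(p^2): the squared term in question
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)  # soft Dice coefficient
        loss = 1 - loss                        # convert the coefficient into a loss
        return loss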

Why are y_sum and z_sum computed as sums of the squared target and score tensors, respectively? According to the standard Dice formula, each tensor should simply be summed, without squaring it.
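For reference, here are the two forms of the soft Dice coefficient being compared. The notation is introduced here for illustration: p_i is the predicted score and g_i the target at element i, so in the snippet `intersect` = Σ p_i g_i, `z_sum` = Σ p_i², `y_sum` = Σ g_i² and `smooth` = ε. The squared form is the one defined in the V-Net paper.

```latex
% Plain (linear) soft Dice: the denominator sums the values directly
D_{\mathrm{lin}} = \frac{2\sum_i p_i g_i + \varepsilon}{\sum_i p_i + \sum_i g_i + \varepsilon}

% Squared soft Dice, as computed by _dice_loss
D_{\mathrm{sq}} = \frac{2\sum_i p_i g_i + \varepsilon}{\sum_i p_i^2 + \sum_i g_i^2 + \varepsilon}

% Loss in both cases: L = 1 - D
% For a binary target, g_i^2 = g_i, so only the prediction term differs.
% For p_i \in [0, 1], p_i^2 \le p_i, hence \sum_i p_i^2 \le \sum_i p_i and D_{\mathrm{sq}} \ge D_{\mathrm{lin}}.
% The two forms coincide when every p_i is exactly 0 or 1.
```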

hdnminh commented 6 months ago

If anyone is interested, see the V-Net paper, whose Dice loss uses squared terms in the denominator.
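To see the difference numerically, here is a minimal PyTorch sketch. The function `dice_loss` and its `squared` flag are hypothetical names introduced only for this comparison; with `squared=True` it performs the same computation as the quoted `_dice_loss` method, written as a standalone function.

```python
import torch


def dice_loss(score: torch.Tensor, target: torch.Tensor,
              squared: bool = True, smooth: float = 1e-5) -> torch.Tensor:
    """Soft Dice loss for one class (hypothetical helper for comparison).

    squared=True reproduces the quoted _dice_loss (squared denominator);
    squared=False sums score and target directly.
    """
    target = target.float()
    intersect = torch.sum(score * target)
    if squared:
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
    else:
        y_sum = torch.sum(target)
        z_sum = torch.sum(score)
    dice = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
    return 1 - dice


# Soft predictions in [0, 1] and a binary ground-truth mask.
score = torch.tensor([0.9, 0.6, 0.2, 0.1])
target = torch.tensor([1, 1, 0, 0])

# For a 0/1 target, squaring changes nothing: target * target == target.
assert torch.equal(target * target, target)

loss_sq = dice_loss(score, target, squared=True)
loss_lin = dice_loss(score, target, squared=False)
print(f"squared: {loss_sq.item():.4f}  linear: {loss_lin.item():.4f}")

# With hard 0/1 predictions the two variants agree.
hard = (score > 0.5).float()
print(f"hard squared: {dice_loss(hard, target, True).item():.4f}  "
      f"hard linear: {dice_loss(hard, target, False).item():.4f}")
```

With these soft predictions the squared form gives a loss of about 0.068, versus about 0.211 for the plain form. For the thresholded 0/1 predictions both losses are 0, since squaring leaves 0 and 1 unchanged. The only practical difference is therefore how soft (non-binary) predictions are weighted in the denominator.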