ljwztc / CLIP-Driven-Universal-Model

[ICCV 2023] CLIP-Driven Universal Model; Rank first in MSD Competition.

loss.py BinaryDiceLoss use of dim = 1 when flattening the tensor. #83

Open · skapoor2024 opened 2 months ago

skapoor2024 commented 2 months ago

In the provided loss function:

import torch
import torch.nn as nn

class BinaryDiceLoss(nn.Module):
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        predict = predict.contiguous().view(-1)  # flattens to a 1-D tensor
        target = target.contiguous().view(-1)

        num = torch.sum(torch.mul(predict, target), dim=1)  # dim=1, but the tensor is now 1-D
        den = torch.sum(predict, dim=1) + torch.sum(target, dim=1) + self.smooth

        dice_score = 2*num / den
        dice_loss = 1 - dice_score

        dice_loss_avg = dice_loss[target[:,0]!=-1].sum() / dice_loss[target[:,0]!=-1].shape[0]

        return dice_loss_avg

When we flatten the predict and target arrays, why do we sum with dim = 1? Shouldn't it be dim = 0, since there is only one dimension left after view(-1)? Also, when predict and target are passed to the loss function, they are selected per batch element and per organ, so each is a 3-D volume (D, H, W).
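For reference, a minimal standalone repro (hypothetical shapes, not from the repo) of the shape mismatch: after view(-1) the tensors are one-dimensional, so dim = 0 is the only valid reduction axis and dim = 1 is out of range.

import torch

# Hypothetical per-organ volume; exact sizes don't matter.
predict = torch.rand(8, 8, 8).view(-1)                   # shape: (512,), a single dimension
target = torch.randint(0, 2, (8, 8, 8)).float().view(-1)

print(torch.sum(predict * target, dim=0))  # fine: dim 0 is the only axis
try:
    torch.sum(predict * target, dim=1)     # out of range for a 1-D tensor
except IndexError as e:
    print(e)  # "Dimension out of range (expected to be in range of [-1, 0], but got 1)"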

ljwztc commented 2 months ago

Makes sense. This should be 0. But why did we encounter no errors when running this code? Maybe an earlier version of PyTorch supported this summation?
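One way to settle this is to probe the installed PyTorch directly; on any recent release I'm aware of, this summation raises an IndexError rather than being silently accepted. A quick sanity check (not from the repo):

import torch

print(torch.__version__)
x = torch.ones(10)  # stand-in for the flattened tensor
try:
    torch.sum(x, dim=1)
    print("dim=1 accepted on a 1-D tensor")
except IndexError as err:
    print("dim=1 rejected:", err)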

skapoor2024 commented 2 months ago

I believe we should make the following change:

        predict = predict.contiguous().view(1,-1)
        target = target.contiguous().view(1,-1)

This would make the dim=1 reductions work properly, since D, H, and W are concatenated and the final shape becomes (1, D*H*W).
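Putting that together, here is a sketch of how the full forward could look with this change (my reading of the proposal, untested against the repo). It also keeps the target[:,0] indexing at the end valid, which would itself fail on a 1-D tensor if we only changed dim to 0:

import torch
import torch.nn as nn

class BinaryDiceLoss(nn.Module):
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # Flatten to (1, D*H*W) so the dim=1 reductions below stay valid.
        predict = predict.contiguous().view(1, -1)
        target = target.contiguous().view(1, -1)

        num = torch.sum(torch.mul(predict, target), dim=1)
        den = torch.sum(predict, dim=1) + torch.sum(target, dim=1) + self.smooth

        dice_score = 2 * num / den
        dice_loss = 1 - dice_score

        # target[:, 0] now has shape (1,), so the -1 mask (unlabeled organs) still works.
        dice_loss_avg = dice_loss[target[:, 0] != -1].sum() / dice_loss[target[:, 0] != -1].shape[0]

        return dice_loss_avg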