ooooverflow / BiSeNet

BiSeNet based on pytorch
394 stars 77 forks source link

Right Dice Loss? #41

Open hitsz-zuoqi opened 3 years ago

hitsz-zuoqi commented 3 years ago

class DiceLoss(nn.Module): def __init__(self): super().__init__() self.epsilon = 1e-5

def forward(self, output, target):
    # print(output.shape)
    # print(target.shape)

    assert output.size() == target.size(), "'input' and 'target' must have the same shape"
    # 在classes上做softmax
    output = F.softmax(output, dim=1)
    # 打平tensor
    output = flatten(output) # [num_classes,B*H*W]
    target = flatten(target) # [num_classes,B*H*W]
    # intersect = (output * target).sum(-1).sum() + self.epsilon
    # denominator = ((output + target).sum(-1)).sum() + self.epsilon

    intersect = (output * target).sum(-1)
    denominator = (output + target).sum(-1)
    # dice --(0-0.5)
    dice = intersect / denominator
    dice = torch.mean(dice)
    # 1-dice (0.5,1)???
    return 1 - dice
    # return 1 - 2. * intersect / denominator

double the intersection over union?