Open BaronDuan opened 10 months ago
def Dice_loss(inputs, target, beta=1, smooth=1e-5):
    """Return the soft Dice (F-beta) loss between logits and a channel-last target.

    Args:
        inputs: 4-D tensor unpacked as ``(n, c, h, w)``. Softmax over the
            ``c`` axis is applied inside this function.
        target: 4-D tensor unpacked as ``(n, ht, wt, ct)``. Its last channel
            is dropped before comparison, so ``ct - 1`` has to line up with
            ``c``. NOTE(review): presumably the extra channel marks
            ignored/void pixels -- confirm against the data pipeline.
        beta: F-beta weight; false negatives are scaled by ``beta ** 2``
            relative to false positives.
        smooth: Constant added to numerator and denominator so a class with
            no true/false positives does not produce 0/0.

    Returns:
        Scalar tensor equal to ``1 - mean(score)``, the mean being taken over
        the class axis.
    """
    n, c, h, w = inputs.size()
    _, ht, wt, ct = target.size()

    # Resize when EITHER spatial dimension differs. The previous `and` skipped
    # the resize when only one of them differed, which left the flattened
    # prediction and target with different pixel counts.
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    # Flatten to (n, pixels, channels). reshape() also accepts non-contiguous
    # tensors (e.g. a permuted target), where view() would raise.
    temp_inputs = torch.softmax(inputs.permute(0, 2, 3, 1).reshape(n, -1, c), -1)
    temp_target = target.reshape(n, -1, ct)[..., :-1]

    # Compute dice loss: per-class soft counts summed over batch and pixels.
    tp = torch.sum(temp_target * temp_inputs, dim=(0, 1))
    fp = torch.sum(temp_inputs, dim=(0, 1)) - tp
    fn = torch.sum(temp_target, dim=(0, 1)) - tp

    beta_sq = beta ** 2
    score = ((1 + beta_sq) * tp + smooth) / ((1 + beta_sq) * tp + beta_sq * fn + fp + smooth)
    return 1 - torch.mean(score)
def Dice_loss(inputs, target, beta=1, smooth = 1e-5): n, c, h, w = inputs.size() nt, ht, wt, ct = target.size() if h != ht and w != wt: inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)