bubbliiiing / unet-pytorch

This is the source code for unet-pytorch, which you can use to train your own model.
MIT License

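Hi, I have a question about the dice_loss function defined in the unet_training code: what does temp_target[...,:-1] mean, and what is that slice indexing? Also, what does axis mean in the torch.sum calls? #88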
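Both pieces of syntax in the question can be illustrated on a small tensor. This is a minimal sketch that assumes only that `temp_target` has shape (n, h*w, ct), as in the code below. The toy tensor `x` and its sizes are made up for illustration.

```python
import torch

# Toy tensor shaped like temp_target: (batch n, pixels h*w, channels).
# Here n=2, 4 pixels per image, 3 channels.
x = torch.arange(24, dtype=torch.float32).view(2, 4, 3)

# `...` stands for "all leading dimensions"; `:-1` keeps every index of the
# last dimension except the final one, so only the channel axis is sliced.
sliced = x[..., :-1]
print(sliced.shape)  # torch.Size([2, 4, 2])

# torch.sum accepts `axis` as a NumPy-style alias for `dim`: the listed
# dimensions are summed away. Summing over 0 (batch) and 1 (pixels)
# leaves one total per channel.
per_channel = torch.sum(x, axis=[0, 1])
print(per_channel.tolist())  # [84.0, 92.0, 100.0]
print(torch.equal(per_channel, x.sum(dim=(0, 1))))  # True
```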

Open · BaronDuan opened this issue 10 months ago

BaronDuan commented 10 months ago

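```python
import torch
import torch.nn.functional as F

def Dice_loss(inputs, target, beta=1, smooth=1e-5):
    # inputs: logits (n, c, h, w); target: one-hot labels (n, h, w, ct)
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    # resize the logits if either spatial size differs from the labels
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    # flatten pixels: logits -> (n, h*w, c) with softmax over classes; target -> (n, h*w, ct)
    temp_inputs = torch.softmax(inputs.permute(0, 2, 3, 1).contiguous().view(n, -1, c), -1)
    temp_target = target.view(n, -1, ct)

    # compute the dice loss: per-class sums over batch and pixels (dims 0 and 1)
    tp = torch.sum(temp_target[..., :-1] * temp_inputs, axis=[0, 1])
    fp = torch.sum(temp_inputs, axis=[0, 1]) - tp
    fn = torch.sum(temp_target[..., :-1], axis=[0, 1]) - tp

    # F-beta score per class; beta=1 gives the Dice coefficient
    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss
```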
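Here is a sketch of where the extra last channel may come from. It assumes that the repo's data loader maps "ignore" pixels (for example label 255) to the index num_classes and one-hot encodes with num_classes + 1 columns, so that ct = c + 1. The issue text does not confirm this preprocessing, so treat it as an assumption. The label map, `num_classes = 3`, and the random logits are hypothetical.

```python
import numpy as np
import torch

num_classes = 3  # hypothetical: 3 real classes
# Hypothetical 2x3 label map; 255 marks a pixel to ignore.
png = np.array([[0, 1, 2],
                [2, 255, 0]])

# Assumed preprocessing: labels >= num_classes go to the extra index
# num_classes, then one-hot encode with num_classes + 1 columns.
png[png >= num_classes] = num_classes
seg_labels = np.eye(num_classes + 1)[png.reshape(-1)]
target = torch.from_numpy(seg_labels).float().view(1, 2, 3, num_classes + 1)  # (n, h, w, ct)

n, h, w, ct = target.shape
temp_target = target.view(n, -1, ct)  # (n, h*w, ct)
real = temp_target[..., :-1]          # drop the extra channel -> (n, h*w, num_classes)

# The ignored pixel (flattened index 4) becomes an all-zero row, so it adds
# nothing to tp or fn; its softmax probabilities still count toward fp.
print(real[0, 4])  # tensor([0., 0., 0.])

# Random logits stand in for a network output.
torch.manual_seed(0)
logits = torch.randn(n, num_classes, h, w)
probs = torch.softmax(logits.permute(0, 2, 3, 1).reshape(n, -1, num_classes), -1)

tp = torch.sum(real * probs, axis=[0, 1])
fp = torch.sum(probs, axis=[0, 1]) - tp
fn = torch.sum(real, axis=[0, 1]) - tp

# With beta = 1 the score reduces to Dice: (2*TP + s) / (2*TP + FN + FP + s).
smooth = 1e-5
score = (2 * tp + smooth) / (2 * tp + fn + fp + smooth)
loss = 1 - torch.mean(score)
print(loss)
```

Summing over axis=[0, 1] pools every pixel of every image in the batch before the ratio is taken. As a result, the score is computed once per class rather than once per image.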