Open hitsz-zuoqi opened 3 years ago
class DiceLoss(nn.Module): def __init__(self): super().__init__() self.epsilon = 1e-5
def forward(self, output, target): # print(output.shape) # print(target.shape) assert output.size() == target.size(), "'input' and 'target' must have the same shape" # 在classes上做softmax output = F.softmax(output, dim=1) # 打平tensor output = flatten(output) # [num_classes,B*H*W] target = flatten(target) # [num_classes,B*H*W] # intersect = (output * target).sum(-1).sum() + self.epsilon # denominator = ((output + target).sum(-1)).sum() + self.epsilon intersect = (output * target).sum(-1) denominator = (output + target).sum(-1) # dice --(0-0.5) dice = intersect / denominator dice = torch.mean(dice) # 1-dice (0.5,1)??? return 1 - dice # return 1 - 2. * intersect / denominator
double the intersection over union?
class DiceLoss(nn.Module): def __init__(self): super().__init__() self.epsilon = 1e-5
double the intersection over union?