Overflowu7 opened 1 year ago
Hi Overflowu7, could you please provide your modified `_build_loss` and `DC_and_CE_loss` functions?
OK, here is my code:
```python
import torch
from torch import nn

from nnunetv2.training.loss.dice import SoftDiceLoss
from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss
from nnunetv2.utilities.helpers import softmax_helper_dim1

# BTI_Loss and BoundaryDoULoss are my own additions, imported from my own modules


class DC_and_CE_and_TIOU_Loss(nn.Module):
    def __init__(self, soft_dice_kwargs, ce_kwargs, ti_kwargs, weight_ce=1, weight_dice=1,
                 weight_ti=1e-6, weight_bou=0, ignore_label=None, dice_class=SoftDiceLoss):
        """
        Weights for CE and Dice do not need to sum to one. You can set whatever you want.
        :param soft_dice_kwargs:
        :param ce_kwargs:
        :param ti_kwargs:
        :param weight_ce:
        :param weight_dice:
        :param weight_ti:
        """
        super(DC_and_CE_and_TIOU_Loss, self).__init__()
        if ignore_label is not None:
            ce_kwargs['ignore_index'] = ignore_label

        self.weight_dice = weight_dice
        self.weight_ce = weight_ce
        self.weight_ti = weight_ti
        self.weight_bou = weight_bou
        self.ignore_label = ignore_label

        self.ce = RobustCrossEntropyLoss(**ce_kwargs)
        self.dc = dice_class(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs)
        self.ti = BTI_Loss(**ti_kwargs)
        self.bou = BoundaryDoULoss(n_classes=5)

    def forward(self, net_output: torch.Tensor, target: torch.Tensor):
        """
        target must be b, c, x, y(, z) with c=1
        :param net_output:
        :param target:
        :return:
        """
        # self.adjust_weights(self.epoch)
        if self.ignore_label is not None:
            assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \
                                         '(DC_and_CE_loss)'
            mask = (target != self.ignore_label).bool()
            # remove ignore label from target, replace with one of the known labels. It doesn't matter because we
            # ignore gradients in those areas anyway
            target_dice = torch.clone(target)
            target_dice[target == self.ignore_label] = 0
            num_fg = mask.sum()
        else:
            target_dice = target
            mask = None

        dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \
            if self.weight_dice != 0 else 0
        ce_loss = self.ce(net_output, target[:, 0].long()) \
            if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0
        bti_loss = self.ti(net_output, target) if self.weight_ti != 0 else 0
        bou_loss = self.bou(net_output, target)
        # print(self.weight_dice, self.weight_dice)
        # print("current epoch:", self.epoch)
        result = self.weight_ce * ce_loss + self.weight_ti * bti_loss \
                 + self.weight_dice * dc_loss + self.weight_bou * bou_loss
        return result
```
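The commented-out `self.adjust_weights(self.epoch)` call in `forward` is the direction I tried. A sketch of what that method could look like (the body is hypothetical, and `total_epochs` is an assumed parameter; something outside the loss still has to pass the current epoch in):

```python
# hypothetical method on DC_and_CE_and_TIOU_Loss, matching the commented-out
# self.adjust_weights(self.epoch) call in forward
def adjust_weights(self, epoch: int, total_epochs: int = 1000):
    t = epoch / max(1, total_epochs - 1)  # training progress in [0, 1]
    self.weight_dice = 1.0 - t  # decays from 1 to 0
    self.weight_bou = t         # grows from 0 to 1
    # self.weight_ce stays at 1
```

`_build_loss` in my trainer subclass then builds the combined loss: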
```python
import numpy as np

from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper


# method of my nnUNetTrainer subclass; it replaces the stock has_regions /
# DC_and_CE_loss branch of nnUNetTrainer._build_loss with the combined loss
def _build_loss(self):
    patch_size = self.configuration_manager.patch_size
    dim = len(patch_size)
    connectivity = 26
    lambda_ti = 1e-6
    lambda_bou = 0

    inclusion_list = []
    exclusion_list = [[1, 2, 3, 4], [[1, 2, 3], [4]], [[1, 2], [3]]]
    inclusion_list = self.make_tensors(inclusion_list, self.device)
    exclusion_list = self.make_tensors(exclusion_list, self.device)

    loss = DC_and_CE_and_TIOU_Loss(
        {'batch_dice': self.configuration_manager.batch_dice, 'smooth': 1e-5,
         'do_bg': False, 'ddp': self.is_ddp},
        {},
        {'dim': dim, 'connectivity': connectivity, 'inclusion': inclusion_list,
         'exclusion': exclusion_list, 'min_thick': 1},
        weight_ce=1, weight_dice=1, weight_ti=lambda_ti, weight_bou=lambda_bou,
        ignore_label=self.label_manager.ignore_label,
        dice_class=MemoryEfficientSoftDiceLoss)

    if self.enable_deep_supervision:
        deep_supervision_scales = self._get_deep_supervision_scales()
        # we give each output a weight which decreases exponentially (division by 2) as the
        # resolution decreases. This gives higher resolution outputs more weight in the loss
        weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
        # the lowest resolution output is not used. Normalize weights so that they sum to 1
        weights[-1] = 0
        weights = weights / weights.sum()
        # now wrap the loss
        loss = DeepSupervisionWrapper(loss, weights)
    return loss
```
I want to keep `weight_ce` at 1, have `weight_dice` decrease from 1 to 0, and have `weight_bou` grow from 0 to 1 as the epoch increases. Because I am not sure how `nnUNetTrainer` is coupled to the rest of the code, all my attempts to change this have failed.

If I add an epoch parameter to `_build_loss`, it creates other coupling issues in `run_training`, and if I read `self.current_epoch` directly, it always has a value of 0. I don't know how to solve this problem; what approach should I take? I don't know the overall architecture very well yet!
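One possible way around this (a sketch, not verified against this exact setup): `_build_loss` runs once during `nnUNetTrainer.initialize()`, before training starts, which is why `self.current_epoch` is still 0 there. Instead of baking the epoch into the loss at build time, leave the weights as the mutable attributes they already are and update them at the start of each epoch from a trainer hook. With deep supervision enabled, `self.loss` is a `DeepSupervisionWrapper` and the combined loss sits in its `.loss` attribute. The subclass name below is made up, and `adjust_weights` is the hypothetical method sketched above:

```python
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


class nnUNetTrainerScheduledWeights(nnUNetTrainer):  # hypothetical subclass name
    def on_train_epoch_start(self):
        super().on_train_epoch_start()
        # unwrap DeepSupervisionWrapper when deep supervision is enabled
        combined = self.loss.loss if self.enable_deep_supervision else self.loss
        # push the current epoch into the loss; weight_ce stays 1,
        # weight_dice ramps 1 -> 0 and weight_bou ramps 0 -> 1
        combined.adjust_weights(self.current_epoch, self.num_epochs)
```

Because `forward` reads the `self.weight_*` attributes at call time, mutating them between epochs takes effect immediately, and nothing in `run_training` needs to change.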