MIC-DKFZ / nnUNet


How to make weights to decrease from 1 to 0 as the epoch increases? #1804

Open Overflowu7 opened 1 year ago

Overflowu7 commented 1 year ago

I want weight_ce=1 and weight_dice=1 to decrease from 1 to 0 as the epoch increases, but if I add an epoch parameter to the _build_loss function, there are other coupling issues in run_training. If I call self.current_epoch directly, it always has the value 0. I don't know how to solve this; what approach should I take? I don't know the overall architecture very well yet!

   loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
                          'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {},
                         weight_ce=1, weight_dice=1,
                         ignore_label=self.label_manager.ignore_label,
                         dice_class=MemoryEfficientSoftDiceLoss)
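
Concretely, the schedule I have in mind is just a linear ramp over the training epochs. A minimal sketch of the intended behaviour (the helper name and the 1000-epoch count are only for illustration):

def linear_ramp(current_epoch: int, num_epochs: int) -> float:
    """Training progress in [0, 1]: 0 at the first epoch, 1 at the last."""
    return min(max(current_epoch / max(num_epochs - 1, 1), 0.0), 1.0)


# both weight_ce and weight_dice ramp from 1 down to 0 as training progresses
for epoch in (0, 250, 500, 999):
    t = linear_ramp(epoch, 1000)
    print(epoch, {'weight_ce': 1.0 - t, 'weight_dice': 1.0 - t})
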
Kobalt93 commented 9 months ago

Hi Overflowu7, could you please provide your modified _build_loss and DC_and_CE_loss functions?

Overflowu7 commented 9 months ago

OK, here is my code:


# imports not shown in the original snippet; standard nnU-Net v2 locations assumed
# (BTI_Loss and BoundaryDoULoss are custom/external)
import torch
import numpy as np
from torch import nn
from nnunetv2.training.loss.dice import SoftDiceLoss, MemoryEfficientSoftDiceLoss
from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.utilities.helpers import softmax_helper_dim1


class DC_and_CE_and_TIOU_Loss(nn.Module):
    def __init__(self, soft_dice_kwargs, ce_kwargs, ti_kwargs, weight_ce=1, weight_dice=1, weight_ti=1e-6, weight_bou=0,
                 ignore_label=None,
                 dice_class=SoftDiceLoss):
        """
        Weights for CE and Dice do not need to sum to one. You can set whatever you want.
        :param soft_dice_kwargs:
        :param ce_kwargs:
        :param ti_kwargs:
        :param weight_ce:
        :param weight_dice:
        :param weight_ti:
        """
        super(DC_and_CE_and_TIOU_Loss, self).__init__()
        if ignore_label is not None:
            ce_kwargs['ignore_index'] = ignore_label

        self.weight_dice = weight_dice
        self.weight_ce = weight_ce
        self.weight_ti = weight_ti
        self.weight_bou = weight_bou
        self.ignore_label = ignore_label

        self.ce = RobustCrossEntropyLoss(**ce_kwargs)
        self.dc = dice_class(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs)
        self.ti = BTI_Loss(**ti_kwargs)
        self.bou = BoundaryDoULoss(n_classes=5)

    def forward(self, net_output: torch.Tensor, target: torch.Tensor):
        """
        target must be b, c, x, y(, z) with c=1
        :param net_output:
        :param target:
        :return:
        """
        # self.adjust_weights(self.epoch)
        if self.ignore_label is not None:
            assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \
                                         '(DC_and_CE_and_TIOU_Loss)'
            mask = (target != self.ignore_label).bool()
            # remove ignore label from target, replace with one of the known labels. It doesn't matter because we
            # ignore gradients in those areas anyway
            target_dice = torch.clone(target)
            target_dice[target == self.ignore_label] = 0
            num_fg = mask.sum()
        else:
            target_dice = target
            mask = None

        dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \
            if self.weight_dice != 0 else 0
        ce_loss = self.ce(net_output, target[:, 0].long()) \
            if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0

        bti_loss = self.ti(net_output, target) if self.weight_ti != 0 else 0
        bou_loss = self.bou(net_output, target)
        # print(self.weight_dice, self.weight_ce)
        # print("now it is:", self.epoch)
        result = (self.weight_ce * ce_loss + self.weight_ti * bti_loss
                  + self.weight_dice * dc_loss + self.weight_bou * bou_loss)
        return result

    # --- _build_loss below belongs to the nnUNetTrainer subclass, not to the loss class ---
    def _build_loss(self):
        # if self.label_manager.has_regions:
        #     loss = DC_and_BCE_loss({},
        #                            {'batch_dice': self.configuration_manager.batch_dice,
        #                             'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},
        #                            use_ignore_label=self.label_manager.ignore_label is not None,
        #                            dice_class=MemoryEfficientSoftDiceLoss)
        # else:
        #     loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
        #                            'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,
        #                           ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss)
        patch_size = self.configuration_manager.patch_size
        dim = len(patch_size)
        connectivity = 26
        lambda_ti = 1e-6
        lambda_bou = 0
        inclusion_list = []
        exclusion_list = [[1, 2, 3, 4], [[1, 2, 3], [4]], [[1, 2], [3]]]

        inclusion_list = self.make_tensors(inclusion_list, self.device)
        exclusion_list = self.make_tensors(exclusion_list, self.device)

        loss = DC_and_CE_and_TIOU_Loss(
            {'batch_dice': self.configuration_manager.batch_dice, 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp},
            {},
            {'dim': dim, 'connectivity': connectivity, 'inclusion': inclusion_list, 'exclusion': exclusion_list,
             'min_thick': 1},
            weight_ce=1, weight_dice=1, weight_ti=lambda_ti, weight_bou=lambda_bou,
            ignore_label=self.label_manager.ignore_label,
            dice_class=MemoryEfficientSoftDiceLoss)

        # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
        # this gives higher resolution outputs more weight in the loss
        if self.enable_deep_supervision:
            deep_supervision_scales = self._get_deep_supervision_scales()
            weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
            weights[-1] = 0

            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
            weights = weights / weights.sum()
            # now wrap the loss
            loss = DeepSupervisionWrapper(loss, weights)
        return loss

I want to keep weight_ce at 1, have weight_dice decrease from 1 to 0, and have weight_bou grow from 0 to 1 as the epoch increases. Because I am not quite sure about the coupling between nnUNetTrainer and the other parts of the code, all my attempts to change the code have failed.
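
One way to get this behaviour without passing the epoch into _build_loss is to leave the loss construction unchanged and update the weight attributes from the trainer once per epoch. Below is a rough sketch, not a drop-in solution: it assumes an nnU-Net v2 style trainer that exposes self.current_epoch, self.num_epochs, self.enable_deep_supervision and an on_train_epoch_start hook, and that the DeepSupervisionWrapper keeps the wrapped compound loss in its loss attribute (so it is reachable as self.loss.loss); the trainer class name is made up and attribute names can differ between nnU-Net versions.

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


class nnUNetTrainerScheduledLossWeights(nnUNetTrainer):
    # sketch: keep _build_loss as it is and re-weight the loss at every epoch

    def on_train_epoch_start(self):
        super().on_train_epoch_start()  # keep the default per-epoch setup (lr step, logging, ...)

        # training progress in [0, 1]
        t = self.current_epoch / max(self.num_epochs - 1, 1)

        # with deep supervision enabled, self.loss is a DeepSupervisionWrapper,
        # so the compound loss has to be unwrapped first
        loss = self.loss.loss if self.enable_deep_supervision else self.loss

        loss.weight_ce = 1.0        # stays fixed
        loss.weight_dice = 1.0 - t  # decreases 1 -> 0
        loss.weight_bou = t         # grows 0 -> 1

Since _build_loss is only called once when the trainer is initialized, any value of self.current_epoch read inside it is frozen at build time (which is why it always shows 0); mutating the weights on the already-built loss object avoids that problem.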