HiLab-git / ACELoss

Implementations of "Learning Euler's Elastica Model for Medical Image Segmentation"
MIT License

Unable to converge #3

Closed: shijianjian closed this issue 3 years ago

shijianjian commented 3 years ago

Hi, thanks for the repo.

When I use ACELoss, the HD loss stays very high (approx. 1e+4) for hundreds of epochs, while the DSC is only around 0.005. I am wondering whether I am using this repo correctly.

I am doing multi-label segmentation: the input is B1HW and the output is BCHW (where C is the number of classes). Previously I used Dice loss for this task and it worked well (approx. 0.95 DSC). Now I have simply switched from Dice loss alone to 0.8 × Dice + 0.2 × ACE.
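A minimal sketch of the 0.8 × Dice + 0.2 × ACE combination for a BCHW output, assuming a standard soft Dice formulation over one-hot targets (the Dice implementation used in this issue is not shown) and treating the ACE term as a generic callable. `soft_dice_loss`, `combined_loss`, and `ace_fn` are hypothetical names for illustration. Both terms expect probabilities, not logits.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """1 - mean per-class soft Dice.

    probs:  B x C x H x W probabilities (after sigmoid/softmax)
    target: B x C x H x W one-hot (or multi-hot) labels
    """
    dims = (0, 2, 3)  # reduce over batch and spatial dims, keep one score per class
    intersection = torch.sum(probs * target, dims)
    denominator = torch.sum(probs, dims) + torch.sum(target, dims)
    dice_per_class = (2 * intersection + eps) / (denominator + eps)
    return 1 - dice_per_class.mean()


def combined_loss(probs, target, ace_fn, w_dice=0.8, w_ace=0.2):
    """Weighted sum of soft Dice and an ACE-style term; ace_fn(probs, target) -> scalar tensor."""
    return w_dice * soft_dice_loss(probs, target) + w_ace * ace_fn(probs, target)


# Usage with random data: 3 mutually exclusive classes, so softmax over channels.
logits = torch.randn(2, 3, 32, 32)
labels = torch.randint(0, 3, (2, 32, 32))
target = F.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()  # B x C x H x W
probs = torch.softmax(logits, dim=1)
loss = combined_loss(probs, target, ace_fn=lambda p, t: torch.tensor(0.0))  # placeholder ACE term
```

Because soft Dice is bounded in [0, 1], the weights 0.8 and 0.2 only mean what they say if the ACE term is on a comparable scale, which is where the choice of reduction comes in.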

Am I using ACE correctly?

Luoxd1996 commented 3 years ago

Hi shijianjian, thanks for your interest, but I am not sure what problem you have run into. First, ACE Loss is extended from AC Loss. Xu Chen et al. and Jun Ma et al. have shown that AC-based loss functions are useful for single-class segmentation, and like them we only demonstrated the usefulness in binary segmentation tasks, because the multi-class problem has several parameters that need careful tuning. It may also work well there, but we have not tested it. Second, if you want to combine it with Dice loss, please use "mean" in place of "sum"; I have given this tip in the README. Finally, reading the original paper for more details may help. Best, Xiangde.
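To see why "mean" matters when mixing the two losses: Dice loss lies in [0, 1], while a summed ACE term grows with the number of pixels. A minimal sketch of just the region part (it mirrors the structure of the `region` method in the `ACELossVM` class posted later in this thread; `region_term` is a hypothetical helper) shows the summed value scaling with image size while the mean stays fixed:

```python
import torch


def region_term(y_pred, y_true, u=1.0, reduction="sum"):
    """Region part of an AC-style loss: u * sum(p * (y - 1)^2) + sum((1 - p) * y^2)."""
    reduce = torch.sum if reduction == "sum" else torch.mean
    region_in = torch.abs(reduce(y_pred * (y_true - 1.0) ** 2))
    region_out = torch.abs(reduce((1.0 - y_pred) * y_true ** 2))
    return u * region_in + region_out


for size in (64, 128):
    # Uniform 0.5 prediction against an all-background label.
    y_pred = torch.full((1, 1, size, size), 0.5)
    y_true = torch.zeros(1, 1, size, size)
    total = region_term(y_pred, y_true, reduction="sum").item()
    avg = region_term(y_pred, y_true, reduction="mean").item()
    print(size, total, avg)  # prints "64 2048.0 0.5", then "128 8192.0 0.5"
```

With `reduction="sum"` on a 256×256 image, the same uniform prediction would give 32768, which dwarfs a Dice loss bounded by 1.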

shijianjian commented 3 years ago

@Luoxd1996 Thanks for the quick response. The major mistake I made was passing logits instead of the sigmoid/softmax outputs to the loss function. I can now use "mean" to combine the losses. Thanks for your advice.
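The fix described here, converting logits to probabilities before calling the loss, looks roughly like this in PyTorch. Which activation to use depends on the task (an assumption to check against your own setup): an independent sigmoid per channel for multi-label outputs, or a softmax across the channel dimension when each pixel belongs to exactly one class.

```python
import torch

logits = torch.randn(2, 3, 16, 16)  # raw network output, B x C x H x W

# Multi-label: each channel is an independent binary problem.
probs_multilabel = torch.sigmoid(logits)

# Multi-class: classes are mutually exclusive, so normalize across channels (dim=1).
probs_multiclass = torch.softmax(logits, dim=1)

# Either tensor can then be passed to a loss that does not apply its own
# activation (e.g. the ACELossVM class below constructed with from_logits=False).
```

Passing raw logits instead makes weights such as `(1 - y_pred)` in the region term unbounded and possibly negative, which is one plausible source of very large, unstable loss values.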

I am doing an organ segmentation task and am currently using beta = 1, which does not perform as well as Dice loss alone (Dice only: 0.96; Dice + ACE: 0.93). I noticed in the paper that beta is suggested to be in (2, 10) for this kind of task. Can you share your thoughts on how to select the best beta value?

Thank you!

FYI, the code:

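import torch
import torch.nn as nn


class ACELossVM(nn.Module):
    """
    Active contour loss with an Euler's elastica term,
    based on total variation and mean curvature.

    When using this loss only as a constraint (combined with Dice or CE loss),
    set reduction='mean' so that torch.mean() replaces torch.sum().

    Suggested ranges for β (the curvature weight `b`):
        a) 0 < β < 2 for curvilinear or tubular structures
        b) 2 < β < 10 for non-tubular structures
    """
    def __init__(self, u=1, a=1e-3, b=1, from_logits=True, reduction='sum') -> None:
        super().__init__()
        if reduction not in ('mean', 'sum'):
            raise ValueError(f"reduction must be 'mean' or 'sum', got {reduction!r}")
        self.u = u  # weight of region_in in region()
        self.a = a  # weight of the length term in elastica()
        self.b = b  # weight of the squared-curvature term (β)
        self.from_logits = from_logits  # apply sigmoid to y_pred in forward()
        self.reduction = reduction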
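    def first_derivative(self, input):
        # Central differences along H (ci, dim 2) and W (cj, dim 3),
        # one-sided at the borders; every entry is divided by 2.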
        u = input
        m = u.shape[2]
        n = u.shape[3]

        ci_0 = (u[:, :, 1, :] - u[:, :, 0, :]).unsqueeze(2)
        ci_1 = u[:, :, 2:, :] - u[:, :, 0:m - 2, :]
        ci_2 = (u[:, :, -1, :] - u[:, :, m - 2, :]).unsqueeze(2)
        ci = torch.cat([ci_0, ci_1, ci_2], 2) / 2

        cj_0 = (u[:, :, :, 1] - u[:, :, :, 0]).unsqueeze(3)
        cj_1 = u[:, :, :, 2:] - u[:, :, :, 0:n - 2]
        cj_2 = (u[:, :, :, -1] - u[:, :, :, n - 2]).unsqueeze(3)
        cj = torch.cat([cj_0, cj_1, cj_2], 3) / 2

        return ci, cj

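    def second_derivative(self, input, ci, cj):
        # cii, cjj: second differences along H and W (reduce to first differences at the borders).
        # cij: difference of ci along W, NOT divided by 2 (unlike first_derivative).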
        u = input
        # m = u.shape[2]
        n = u.shape[3]

        cii_0 = (u[:, :, 1, :] + u[:, :, 0, :] - 2 * u[:, :, 0, :]).unsqueeze(2)
        cii_1 = u[:, :, 2:, :] + u[:, :, :-2, :] - 2 * u[:, :, 1:-1, :]
        cii_2 = (u[:, :, -1, :] + u[:, :, -2, :] - 2 * u[:, :, -1, :]).unsqueeze(2)
        cii = torch.cat([cii_0, cii_1, cii_2], 2)

        cjj_0 = (u[:, :, :, 1] + u[:, :, :, 0] - 2 * u[:, :, :, 0]).unsqueeze(3)
        cjj_1 = u[:, :, :, 2:] + u[:, :, :, :-2] - 2 * u[:, :, :, 1:-1]
        cjj_2 = (u[:, :, :, -1] + u[:, :, :, -2] - 2 * u[:, :, :, -1]).unsqueeze(3)

        cjj = torch.cat([cjj_0, cjj_1, cjj_2], 3)

        cij_0 = ci[:, :, :, 1:n]
        cij_1 = ci[:, :, :, -1].unsqueeze(3)

        cij_a = torch.cat([cij_0, cij_1], 3)
        cij_2 = ci[:, :, :, 0].unsqueeze(3)
        cij_3 = ci[:, :, :, 0:n - 1]
        cij_b = torch.cat([cij_2, cij_3], 3)
        cij = cij_a - cij_b

        return cii, cjj, cij

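    def region(self, y_pred, y_true, u=1):
        # region_in penalizes foreground predicted on background pixels (label 0);
        # region_out penalizes missed foreground (label 1).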
        label = y_true.float()
        c_in = torch.ones_like(y_pred)
        c_out = torch.zeros_like(y_pred)
        if self.reduction == 'mean':
            region_in = torch.abs(torch.mean(y_pred * ((label - c_in) ** 2)))
            region_out = torch.abs(torch.mean((1 - y_pred) * ((label - c_out) ** 2)))
        elif self.reduction == 'sum':
            region_in = torch.abs(torch.sum(y_pred * ((label - c_in) ** 2)))
            region_out = torch.abs(torch.sum((1 - y_pred) * ((label - c_out) ** 2)))
        else:
            raise ValueError(f"Unsupported reduction: {self.reduction!r}")
        region = u * region_in + region_out
        return region

    def elastica(self, input, a=1, b=1):
        # Euler's elastica: (a + b * curvature^2) weighted by the gradient magnitude (length).
        ci, cj = self.first_derivative(input)
        cii, cjj, cij = self.second_derivative(input, ci, cj)
        eps = 1e-8  # numerical stabilizer (not the curvature weight β)
        length = torch.sqrt(eps + ci ** 2 + cj ** 2)
        curvature = (eps + ci ** 2) * cjj + (eps + cj ** 2) * cii - 2 * ci * cj * cij
        curvature = torch.abs(curvature) / ((ci ** 2 + cj ** 2) ** 1.5 + eps)
        if self.reduction == 'mean':
            elastica = torch.mean((a + b * (curvature ** 2)) * torch.abs(length))
        elif self.reduction == 'sum':
            elastica = torch.sum((a + b * (curvature ** 2)) * torch.abs(length))
        else:
            raise ValueError(f"Unsupported reduction: {self.reduction!r}")
        return elastica

    def forward(self, y_pred, y_true):
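        if self.from_logits:
            # Independent per-channel sigmoid; if softmax was already applied,
            # construct the loss with from_logits=False instead.
            y_pred = torch.sigmoid(y_pred)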
        loss = self.region(y_pred, y_true, u=self.u) + self.elastica(y_pred, a=self.a, b=self.b)
        return loss
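Read directly off the code above with reduction='sum', the loss can be written as follows. The notation is ours, not the paper's: p is `y_pred` after the sigmoid, y is the label, ε = 1e-8 is the small stabilizing constant in `elastica`, and p_i, p_j, p_{ii}, p_{jj}, p_{ij} stand for `ci`, `cj`, `cii`, `cjj`, `cij`. Sums run over all batch, channel, and pixel positions, and the `torch.abs` wrappers on the region sums are no-ops when p ∈ [0, 1]. With reduction='mean', each sum becomes an average.

```latex
\mathcal{L} = u \sum p\,(y - 1)^2 + \sum (1 - p)\,y^2
  + \sum \bigl(a + b\,\kappa^2\bigr)\sqrt{\varepsilon + p_i^2 + p_j^2},
\qquad
\kappa = \frac{\bigl|(\varepsilon + p_i^2)\,p_{jj} + (\varepsilon + p_j^2)\,p_{ii} - 2\,p_i\,p_j\,p_{ij}\bigr|}{(p_i^2 + p_j^2)^{3/2} + \varepsilon}
```

Ignoring ε, the numerator of κ has the form of the level-set curvature numerator p_j² p_{ii} − 2 p_i p_j p_{ij} + p_i² p_{jj}. Note that `cij` is computed without the 1/2 factor applied in `first_derivative`, so p_{ij} here is about twice a central-difference estimate of the mixed derivative.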
Luoxd1996 commented 3 years ago

Hi shijianjian, as far as I know, there is no good guidance for selecting the best values of "u, a, b". Having too many parameters is also the main drawback of AC-based loss functions; you can search for them in a pre-defined space. In addition, we recommend reading Prof. Xuecheng Tai and Prof. Tony Chan's paper "Image segmentation using Euler’s elastica as the regularization". By the way, the main contribution of our work is not how to select the best parameters, so we only evaluated robustness across a few different values. Best, Xiangde.
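A minimal sketch of searching a pre-defined space, as suggested here. `evaluate` is a hypothetical callable that trains with the given `(u, a, b)` and returns a validation DSC; the candidate values are illustrative, loosely based on the settings tried in this thread and the β ranges in the docstring above, not recommendations from the authors.

```python
import itertools


def grid_search(evaluate, u_values=(1.0,), a_values=(1e-4, 1e-3), b_values=(0.5, 1.0, 2.0, 5.0)):
    """Evaluate every (u, a, b) combination and return the best one (higher score is better)."""
    results = {}
    for u, a, b in itertools.product(u_values, a_values, b_values):
        results[(u, a, b)] = evaluate(u, a, b)
    best = max(results, key=results.get)
    return best, results


# Usage (train_and_validate is a hypothetical function returning a validation DSC):
# best, results = grid_search(train_and_validate)
```

Each call to `evaluate` is typically a full training run, so in practice the grid should stay small, or be replaced by random sampling over the same ranges.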

Luoxd1996 commented 3 years ago

Hi shijianjian, thanks also for your rewritten ACE code; we will update the repo and add more usage details later. Thank you again! Sincerely, Xiangde.

shijianjian commented 3 years ago

I did not do a grid search, but tried a few parameter combinations. With ACE loss alone, I can hardly make it work for my project. Interestingly, the DSC stops improving after only a few epochs (approx. 3):

  1. a=0.001, b=1 >>>> DSC 0.6271
  2. a=0.001, b=2 >>>> DSC 0.5432
  3. a=0.001, b=5 >>>> DSC 0.5763
  4. a=0.0001, b=1 >>>> DSC 0.6107
  5. a=0.0001, b=2 >>>> DSC 0.5763

I also tried adding some randomness to the weights during training, but had no luck there either. It was as simple as the following:

        loss = self.region(y_pred, y_true, u=self.u * np.random.randint(5)) + self.elastica(
            y_pred, a=self.a * np.random.randint(100), b=self.b * np.random.randint(5))
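`np.random.randint(5)` draws from {0, 1, 2, 3, 4} (the upper bound is exclusive), so on some steps the region or curvature term is multiplied by zero and switched off entirely, and `np.random.randint(100)` scales `a` by anywhere from 0 to 99. A minimal sketch of a bounded, strictly positive alternative, assuming multiplicative jitter around the base weights was the intent; `sample_weights` and the [0.5, 2.0) range are hypothetical choices, not something from the repo:

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded for reproducibility


def sample_weights(u, a, b, low=0.5, high=2.0):
    """Jitter each base weight by an independent factor drawn uniformly from [low, high)."""
    scale = rng.uniform(low, high, size=3)
    return u * scale[0], a * scale[1], b * scale[2]


# Inside forward(), the weights for one step would then be drawn as:
# u_t, a_t, b_t = sample_weights(self.u, self.a, self.b)
```

This keeps every term active on every step; whether jitter of this kind helps convergence at all is untested here.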

To me, ACE loss is not easy to get working, and it seems very sensitive to hyperparameter settings. I think addressing this could be critical to getting more people to use ACE loss. I hope it can be improved in your next work!

Luoxd1996 commented 3 years ago

Hi @shijianjian, thanks for your suggestions. I agree: ACE loss is sensitive to its hyperparameters, which is a property of the ACE model itself, although we have reduced them to three (u, a, b). Also, since I do not know your task or the structure of your target, I cannot tell what your problem is. You could try other loss functions for your task. Good luck! Best, Xiangde.