Closed shijianjian closed 3 years ago
Hi shijianjian, thanks for your interest, but I am not sure what problem you have run into. First, ACE Loss is an extension of AC Loss. Xu Chen et al. and Jun Ma et al. have shown that AC-based loss functions are useful for single-class segmentation, and like them we only demonstrated its usefulness on a binary segmentation task. For the multi-class problem there are several parameters that need careful tuning; it may also work well there, but we have not tested it. Second, if you want to combine it with Dice Loss, please use "mean" to replace "sum"; I provided this tip in the README. Finally, you can read the original paper for more details; it may help. Best, Xiangde.
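To illustrate why the "mean" versus "sum" choice matters when mixing with Dice, here is a rough back-of-the-envelope sketch in plain Python. All numbers are made up for illustration (a single 256x256 map, a hypothetical average per-pixel penalty of 0.2, a hypothetical Dice loss of 0.05), and the 0.8/0.2 weights are the ones mentioned in the original question of this thread. It is not the repository's code, only arithmetic on the two reductions.

```python
# Toy numbers only: one per-pixel penalty, reduced by mean vs. by sum.
height, width = 256, 256
per_pixel_penalty = 0.2   # hypothetical average ACE-style penalty per pixel
dice_loss = 0.05          # hypothetical Dice loss, which always lies in [0, 1]

ace_mean = per_pixel_penalty                    # scale of a mean-reduced term
ace_sum = per_pixel_penalty * height * width    # scale of a sum-reduced term

# Weighted combination as described in the thread: 0.8 * Dice + 0.2 * ACE.
combined_mean = 0.8 * dice_loss + 0.2 * ace_mean
combined_sum = 0.8 * dice_loss + 0.2 * ace_sum

# Fraction of the total loss that actually comes from the Dice term.
dice_share_mean = 0.8 * dice_loss / combined_mean
dice_share_sum = 0.8 * dice_loss / combined_sum

print(f"mean: total={combined_mean:.4f}, Dice share={dice_share_mean:.1%}")
print(f"sum:  total={combined_sum:.1f}, Dice share={dice_share_sum:.4%}")
```

Under these assumptions the sum-reduced term grows with the number of pixels, so the Dice term contributes almost nothing to the total, whereas the mean-reduced term stays on a scale comparable to Dice.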
@Luoxd1996 Thanks for the quick response. My main mistake was passing logits, rather than the sigmoid/softmax outputs, to the loss function. I can now use "mean" to integrate the loss. Thanks for your advice.
I am doing an organ segmentation task and am currently using beta = 1 for it, which performs worse than Dice loss alone (Dice only = 0.96, Dice + ACE = 0.93). I noticed that the paper suggests beta in (2, 10) for this kind of task. Can you share your thoughts on how to select the best beta value?
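A small plain-Python sketch of why raw logits are a problem here. It assumes the per-pixel region terms have the form p·(y−1)² and (1−p)·y², as in the code shared in this thread; the logits and labels are hypothetical. With probabilities in [0, 1] both terms are non-negative, but with unbounded logits the factor (1 − p) can become negative.

```python
import math


def sigmoid(x: float) -> float:
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def region_terms(pred, label):
    """Per-pixel region terms: p * (y - 1)^2 inside and (1 - p) * y^2 outside."""
    inside = [p * (y - 1.0) ** 2 for p, y in zip(pred, label)]
    outside = [(1.0 - p) * y ** 2 for p, y in zip(pred, label)]
    return inside, outside


logits = [4.0, -3.0, 2.5, -1.0]   # hypothetical raw network outputs
labels = [1.0, 0.0, 1.0, 0.0]     # hypothetical binary ground truth

# Feeding logits directly: (1 - p) is negative wherever the logit exceeds 1.
_, outside_raw = region_terms(logits, labels)

# Feeding probabilities: every term stays within [0, 1].
probs = [sigmoid(z) for z in logits]
_, outside_prob = region_terms(probs, labels)

print("outside term from logits:       ", outside_raw)
print("outside term from probabilities:", outside_prob)
```

This is only meant to show the sign and range issue; it says nothing about how large the effect is on a real model.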
Thank you!
FYI, the code:
import torch
import torch.nn as nn


class ACELossVM(nn.Module):
    """
    Active Contour Loss
    based on total variations and mean curvature
    to use these methods just as constraints (combining with dice loss or ce loss)
    with torch.mean() to replace torch.sum().
    For instance, for curvilinear or tubular structures image segmentation tasks:
    a) β (0 < β < 2) has better segmentation results
    b) β (2 < β < 10) for non-tubular structures
    """

    def __init__(self, u=1, a=1e-3, b=1, from_logits=True, reduction='sum') -> None:
        super().__init__()
        self.u = u
        self.a = a
        self.b = b
        self.from_logits = from_logits
        self.reduction = reduction

    def first_derivative(self, input):
        # Central differences in the interior, one-sided differences at the borders.
        u = input
        m = u.shape[2]
        n = u.shape[3]
        ci_0 = (u[:, :, 1, :] - u[:, :, 0, :]).unsqueeze(2)
        ci_1 = u[:, :, 2:, :] - u[:, :, 0:m - 2, :]
        ci_2 = (u[:, :, -1, :] - u[:, :, m - 2, :]).unsqueeze(2)
        ci = torch.cat([ci_0, ci_1, ci_2], 2) / 2
        cj_0 = (u[:, :, :, 1] - u[:, :, :, 0]).unsqueeze(3)
        cj_1 = u[:, :, :, 2:] - u[:, :, :, 0:n - 2]
        cj_2 = (u[:, :, :, -1] - u[:, :, :, n - 2]).unsqueeze(3)
        cj = torch.cat([cj_0, cj_1, cj_2], 3) / 2
        return ci, cj

    def second_derivative(self, input, ci, cj):
        u = input
        # m = u.shape[2]
        n = u.shape[3]
        # Second differences along the height axis (dim 2).
        cii_0 = (u[:, :, 1, :] + u[:, :, 0, :] - 2 * u[:, :, 0, :]).unsqueeze(2)
        cii_1 = u[:, :, 2:, :] + u[:, :, :-2, :] - 2 * u[:, :, 1:-1, :]
        cii_2 = (u[:, :, -1, :] + u[:, :, -2, :] - 2 * u[:, :, -1, :]).unsqueeze(2)
        cii = torch.cat([cii_0, cii_1, cii_2], 2)
        # Second differences along the width axis (dim 3).
        cjj_0 = (u[:, :, :, 1] + u[:, :, :, 0] - 2 * u[:, :, :, 0]).unsqueeze(3)
        cjj_1 = u[:, :, :, 2:] + u[:, :, :, :-2] - 2 * u[:, :, :, 1:-1]
        cjj_2 = (u[:, :, :, -1] + u[:, :, :, -2] - 2 * u[:, :, :, -1]).unsqueeze(3)
        cjj = torch.cat([cjj_0, cjj_1, cjj_2], 3)
        # Mixed derivative: difference of ci between neighbouring columns.
        cij_0 = ci[:, :, :, 1:n]
        cij_1 = ci[:, :, :, -1].unsqueeze(3)
        cij_a = torch.cat([cij_0, cij_1], 3)
        cij_2 = ci[:, :, :, 0].unsqueeze(3)
        cij_3 = ci[:, :, :, 0:n - 1]
        cij_b = torch.cat([cij_2, cij_3], 3)
        cij = cij_a - cij_b
        return cii, cjj, cij

    def region(self, y_pred, y_true, u=1):
        label = y_true.float()
        c_in = torch.ones_like(y_pred)
        c_out = torch.zeros_like(y_pred)
        if self.reduction == 'mean':
            region_in = torch.abs(torch.mean(y_pred * ((label - c_in) ** 2)))
            region_out = torch.abs(torch.mean((1 - y_pred) * ((label - c_out) ** 2)))
        elif self.reduction == 'sum':
            region_in = torch.abs(torch.sum(y_pred * ((label - c_in) ** 2)))
            region_out = torch.abs(torch.sum((1 - y_pred) * ((label - c_out) ** 2)))
        else:
            raise ValueError(f"Unsupported reduction: {self.reduction}")
        region = u * region_in + region_out
        return region

    def elastica(self, input, a=1, b=1):
        ci, cj = self.first_derivative(input)
        cii, cjj, cij = self.second_derivative(input, ci, cj)
        beta = 1e-8  # small constant for numerical stability
        length = torch.sqrt(beta + ci ** 2 + cj ** 2)
        curvature = (beta + ci ** 2) * cjj + (beta + cj ** 2) * cii - 2 * ci * cj * cij
        curvature = torch.abs(curvature) / ((ci ** 2 + cj ** 2) ** 1.5 + beta)
        if self.reduction == 'mean':
            elastica = torch.mean((a + b * (curvature ** 2)) * torch.abs(length))
        elif self.reduction == 'sum':
            elastica = torch.sum((a + b * (curvature ** 2)) * torch.abs(length))
        else:
            raise ValueError(f"Unsupported reduction: {self.reduction}")
        return elastica

    def forward(self, y_pred, y_true):
        if self.from_logits:
            y_pred = torch.sigmoid(y_pred)
        loss = self.region(y_pred, y_true, u=self.u) + self.elastica(y_pred, a=self.a, b=self.b)
        return loss
Hi shijianjian, as far as I know there is no good guidance for selecting the best values of "u, a, b". This is the main drawback of AC-based loss functions: too many parameters. You can search for them in a pre-defined space. In addition, we recommend reading Prof. Xuecheng Tai and Prof. Tony Chan's paper, "Image segmentation using Euler’s elastica as the regularization". By the way, selecting the best parameters is not one of our work's main contributions, so we only evaluated robustness at a few different values. Best, Xiangde.
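Searching "u, a, b" over a pre-defined space could look roughly like the sketch below. The helper `grid_search`, the candidate values in `search_space`, and the scoring function `toy_validation_dice` are all hypothetical; in practice the scoring function would train a model with the given parameters and return a validation DSC, which is far more expensive than this placeholder.

```python
from itertools import product


def grid_search(evaluate, space):
    """Try every combination in `space` and return the best-scoring one."""
    names = list(space)
    best_score, best_params = float("-inf"), None
    for values in product(*(space[name] for name in names)):
        params = dict(zip(names, values))
        score = evaluate(**params)
        if score > best_score:
            best_score, best_params = score, params
    return best_params, best_score


# Hypothetical candidate values; the thread does not prescribe any.
search_space = {
    "u": [0.5, 1, 2],
    "a": [1e-3, 1e-2, 1e-1],
    "b": [0.5, 1, 2, 5],
}


def toy_validation_dice(u, a, b):
    # Placeholder for "train with (u, a, b), then measure validation DSC".
    # It is constructed to peak at u=1, a=1e-2, b=2 purely for demonstration.
    return 0.9 - 0.01 * abs(u - 1) - abs(a - 1e-2) - 0.01 * abs(b - 2)


best_params, best_score = grid_search(toy_validation_dice, search_space)
print(best_params, round(best_score, 4))
```

With three values for u, three for a, and four for b, this already means 36 training runs, which is one reason the number of hyperparameters matters.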
Hi shijianjian, thanks also for your rewritten ACE code. We will update it and add more details about usage later. Thank you again! Sincerely, Xiangde.
I did not do a grid search, but I tried a few parameter combinations. With ACE loss alone, I can hardly make it work for my project. Interestingly, the DSC stops improving after a few epochs (approx. 3 epochs).
Additionally, I tried adding some randomness during training, but had no luck there either. It was as simple as the following:
loss = self.region(y_pred, y_true, u=self.u * np.random.randint(5)) + self.elastica(
    y_pred, a=self.a * np.random.randint(100), b=self.b * np.random.randint(5))
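One detail that may be worth checking in this kind of randomization: `np.random.randint(5)` draws an integer from 0 to 4, so a multiplier of 0 is possible and would switch the corresponding term off for that step. The sketch below uses the standard-library `random` module to mimic that behaviour and to show one possible alternative that never yields 0. The helper `sample_multiplier` and the range 1 to 4 are assumptions for illustration, not something proposed in the thread.

```python
import random


def sample_multiplier(rng, low, high):
    """Draw an integer multiplier from [low, high], both ends inclusive."""
    return rng.randint(low, high)


rng = random.Random(0)  # fixed seed so the sketch is reproducible

# Mimics np.random.randint(5): integers 0..4, so 0 can occur.
draws_with_zero = [rng.randrange(5) for _ in range(1000)]

# Hypothetical alternative: integers 1..4, so the term is never zeroed out.
draws_nonzero = [sample_multiplier(rng, 1, 4) for _ in range(1000)]

zero_fraction = draws_with_zero.count(0) / len(draws_with_zero)
print(f"fraction of zero multipliers with range 0..4: {zero_fraction:.2f}")
print(f"smallest multiplier with range 1..4: {min(draws_nonzero)}")
```

Whether removing the zero draws helps training is an open question; the sketch only shows how often a term would be dropped.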
To me, ACE loss is not easy to get working, and it seems very sensitive to hyperparameter settings. I think this is a critical obstacle to getting more people to use ACE loss. I hope it can be improved in your next work!
Hi @shijianjian, thanks for your suggestions. I agree that ACE Loss is sensitive to its hyperparameters. That is a property of the ACE model, although we have reduced them to 3. In addition, I do not know what your task is or what your target's structure looks like, so I cannot tell what the problem is. You can try other loss functions for your task. Good luck. Best, Xiangde.
Hi, thanks for the repo.
When I use ACELoss, the HD Loss stays very high (approx. 1e+4) for hundreds of epochs, while the DSC is around 0.005. I am wondering whether I am using this repo correctly.
What I am trying to do is multi-label segmentation, with B1HW input and BCHW output (where C is the number of classes). Previously, I used Dice loss for my task and it worked well (approx. 0.95 DSC). Now I have simply switched from Dice loss to 0.8 Dice + 0.2 ACE.
I am wondering if I am using ACE correctly or not?