LIVIAETS / boundary-loss

Official code for "Boundary loss for highly unbalanced segmentation", runner-up for best paper award at MIDL 2019. Extended version in MedIA, volume 67, January 2021.
https://doi.org/10.1016/j.media.2020.101851
MIT License

Could you write the code more clearly? #3

Closed. John1231983 closed this issue 5 years ago.

John1231983 commented 5 years ago

Thanks for sharing a very good idea. I am looking at the surface loss function https://github.com/LIVIAETS/surface-loss/blob/108bd9892adca476e6cdf424124bc6268707498e/losses.py#L74

I have some questions:

  1. How do you generate dist_maps from the binary image? I did not find it in your project. The paper mentions distance_transform_edt; do you need to normalize the level set to 0-1? Otherwise the distance map values may range from 0 to 10 or even 100, which seems too high.

  2. If I do 3-class segmentation, so C = 3, then I must compute the distance map of each class: background (class 0), class 1, and class 2, and store them in a BxCxHxW tensor. Is that right?

  3. This coding style does not look like PyTorch. For example, multipled = einsum("bcwh,bcwh->bcwh", pc, dc) could be simplified to multipled = pc * dc, and intersection: Tensor = w * einsum("bcwh,bcwh->bc", pc, tc) to intersection = w * pc * tc.

  4. Why do we need to one-hot the distance map? I think the distance map has size BxCxHxW and the prediction also has size BxCxHxW, so they can be multiplied directly.

https://github.com/LIVIAETS/surface-loss/blob/108bd9892adca476e6cdf424124bc6268707498e/losses.py#L82 Thanks so much

HKervadec commented 5 years ago

Hey,

The function transforming the one-hot encoding to distance maps is here: https://github.com/LIVIAETS/surface-loss/blob/108bd9892adca476e6cdf424124bc6268707498e/utils.py#L198 It computes one distance map per class.
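For reference, its core is roughly the following (a simplified sketch of the repository code, renamed here to avoid confusion). It computes one signed distance map per class: positive outside the object, negative inside.

import numpy as np
from scipy.ndimage import distance_transform_edt as eucl_distance


def one_hot2dist_sketch(seg: np.ndarray) -> np.ndarray:
    # seg: one-hot ground truth of shape (C, H, W)
    C = seg.shape[0]
    res = np.zeros_like(seg, dtype=np.float64)
    for c in range(C):
        posmask = seg[c].astype(bool)
        if posmask.any():
            negmask = ~posmask
            # Positive distance outside the object, negative inside;
            # the innermost boundary pixels of the object sit at zero
            res[c] = eucl_distance(negmask) * negmask \
                - (eucl_distance(posmask) - 1) * posmask
    return res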

To address your questions directly:

otherwise the distance map values may range from 0 to 10 or even 100, which seems too high

Well, that is the point of the paper: to have a much higher loss for pixels far away from the object.

If I do 3-class segmentation, so C = 3, then I must compute the distance map of each class: background (class 0), class 1, and class 2, and store them in a BxCxHxW tensor. Is that right?

Yes. Note that we are still experimenting with the supervision of the background class; treat it as another class, or simply ignore it as we do in the binary case. This is why my losses have the self.idc parameter, to select all or only some classes to supervise. The proper choice will depend on the dataset you are using; I think it won't matter for balanced problems, but for highly unbalanced problems supervising the background class might be harmful.
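To make the idc mechanism concrete, the loss boils down to something like this (a simplified sketch of what losses.py does, not the file verbatim):

import torch
from torch import Tensor, einsum


class SurfaceLoss():
    def __init__(self, **kwargs):
        # idc selects which classes to supervise, e.g. [1, 2] to ignore background
        self.idc: list = kwargs["idc"]

    def __call__(self, probs: Tensor, dist_maps: Tensor, _) -> Tensor:
        # probs: softmax output, bcwh ; dist_maps: signed distance maps, bcwh
        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)

        multipled = einsum("bcwh,bcwh->bcwh", pc, dc)

        return multipled.mean()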

This coding style does not look like PyTorch.

I use einsum extensively because it self-documents the code and the way the data is represented. This way, I do not have to manually track the shapes of my tensors, which is great when you come back to your code after a few months. It is also convenient when you need to multiply and sum at once, as in intersection: Tensor = w * einsum("bcwh,bcwh->bc", pc, tc).
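For instance, the two lines below compute the same per-(batch, class) values, but the einsum version carries the shape information in the code itself:

import torch
from torch import einsum

pc = torch.rand(2, 3, 16, 16)  # bcwh
tc = torch.rand(2, 3, 16, 16)  # bcwh

# Multiply and sum over w and h in one call, shapes documented in the string
a = einsum("bcwh,bcwh->bc", pc, tc)
# The plain-PyTorch equivalent, where you must track the axes yourself
b = (pc * tc).sum(dim=(2, 3))

assert torch.allclose(a, b)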

Type hints are also great when you keep switching between Tensors, NumPy arrays and floats, as they will all behave differently.

Why do we need to one-hot the distance map?

I think you slightly misunderstood. I make sure the distance map is not a valid one-hot encoding (a one-hot tensor contains only 0s and 1s, and sums to 1 across the C axis); it is just a sanity check that I did not mix up the ground truth and the distance map.
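The checks are small helpers, roughly like these (a sketch; utils.py has its own simplex and one_hot versions), so the loss can assert something like simplex(probs) and not is_one_hot(dist_maps):

import torch
from torch import Tensor


def simplex(t: Tensor, axis: int = 1) -> bool:
    # The values sum to 1 across the class axis
    s = t.sum(axis)
    return torch.allclose(s, torch.ones_like(s))


def is_one_hot(t: Tensor, axis: int = 1) -> bool:
    # A valid one-hot encoding contains only 0s and 1s and sums to 1 across C
    return simplex(t, axis) and set(torch.unique(t).tolist()) <= {0, 1}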

John1231983 commented 5 years ago

Thank you for your reply. I spent one day converting your code to a three-class problem. I have a segmentation map of size BxHxW, with values from 0 to 2. First, I convert it to a one-hot tensor using this code:


import torch
import torch.nn as nn


class One_Hot(nn.Module):
    def __init__(self, depth):
        super(One_Hot, self).__init__()
        self.depth = depth
        # Rows of the identity matrix are the one-hot vectors for each class
        self.ones = torch.eye(depth).cuda()

    def forward(self, X_in):
        # X_in: BxHxW integer label map
        n_dim = X_in.dim()
        output_size = X_in.size() + torch.Size([self.depth])
        num_element = X_in.numel()
        X_in = X_in.long().view(num_element)
        # Look up one one-hot row per pixel, then restore BxHxWxC
        out = self.ones.index_select(0, X_in).view(output_size)
        # Move the class axis next to the batch axis: BxCxHxW
        return out.permute(0, -1, *range(1, n_dim)).float()

    def __repr__(self):
        return self.__class__.__name__ + "({})".format(self.depth)

Then, I can use your one_hot2dist function directly. Here is the boundary loss modified for the three-class problem:

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import one_hot2dist  # from this repository


class BoundaryLoss(nn.Module):
    def __init__(self, n_classes):
        super(BoundaryLoss, self).__init__()
        self.n_classes = n_classes
        self.one_hot_encoder = One_Hot(n_classes).forward

    def forward(self, input, target):
        # input: BxCxHxW logits ; target: BxHxW integer labels
        batch_size = input.size(0)
        probs = F.softmax(input, dim=1)               # BxCxHxW
        target_onehot = self.one_hot_encoder(target)  # BxCxHxW
        # Distance maps are computed with NumPy on the CPU, on the 2D maps
        dist = one_hot2dist(target_onehot.cpu().numpy())
        dist = torch.from_numpy(dist).float().cuda()
        score = torch.mean(probs * dist)
        score = score / (float(batch_size) * float(self.n_classes))

        return score

Could you check whether my implementation is correct? Thanks.

HKervadec commented 5 years ago

You are on track, but several things:

You can simply chain the class2one_hot and one_hot2dist functions from utils.py to get all the distance maps, one per class. This part shows how I go from the loaded image of shape wh to C distance maps of shape cwh.

I usually put that part in the dataloader, since it takes some time, the data needs to come back to the CPU anyway, and it can all be pre-computed. Then, during the main training loop, the distance maps are directly available and you can feed them to my implementation; it already handles the multi-class case. A minimal sketch of that dataloader part is below.
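This sketch assumes class2one_hot and one_hot2dist keep the signatures they have in utils.py; the Dataset wrapper itself is hypothetical, so double-check the exact signatures in the repository:

import torch
from torch.utils.data import Dataset

from utils import class2one_hot, one_hot2dist  # functions from this repo


class SegDataset(Dataset):
    def __init__(self, images, labels, n_classes):
        self.images, self.labels, self.K = images, labels, n_classes

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        img = self.images[i]                                       # wh
        seg = torch.as_tensor(self.labels[i], dtype=torch.int64)   # wh, labels in [0, K)
        # class2one_hot works on batched maps; add and strip a batch axis
        one_hot = class2one_hot(seg[None, ...], self.K)[0]         # cwh
        # Distance maps are pre-computed once here, on CPU, with NumPy
        dist_map = one_hot2dist(one_hot.cpu().numpy())             # cwh
        return img, torch.from_numpy(dist_map).float()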

In your case, you would need to do something like this:

surface_loss = SurfaceLoss(idc=[0, 1, 2])  
# Or [1, 2] if you do not care about supervising the background

for input_image, dist_maps in dataloader:
    # input_image: bwh
    # dist_maps: bcwh
    optimizer.zero_grad()

    output_logits = net(input_image)  # bcwh
    output_softmaxes = F.softmax(output_logits, dim=1)  # bcwh

    loss = surface_loss(output_softmaxes, dist_maps, None)
    # The extra parameter is a left-over kept for compatibility; some of my other works use it
    loss.backward()
    optimizer.step()

Concerning your implementation: you do not need to divide by b * c after taking the mean of the products, since the mean already normalizes by the total number of elements. Apart from that I think it is correct, but as stated before, much of the work can be moved into the dataloader to speed things up.
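A quick check of the normalization point:

import torch

x = torch.ones(2, 3, 8, 8)   # b=2, c=3
print(x.mean())              # tensor(1.) -- mean already divides by b*c*h*w
print(x.mean() / (2 * 3))    # tensor(0.1667) -- dividing again shrinks the loss 6x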

Let me know if you have other questions or if I was not clear; I will be happy to help.

John1231983 commented 5 years ago

@HKervadec : It works now. However, it hurts my performance, so I think I misunderstood your intent. Your idea is to use the distance map to guide the segmentation loss, and the distance map is computed from the boundary of the object. So I think the output of your network should be the boundary instead of the whole object region. For example, consider a pixel that belongs to the foreground but lies at its center: it is far from the boundary, so it gets a high value (a large distance) in the distance map. With the surface loss, that pixel would incur a high loss and tend to be pushed towards the background; only pixels near the boundary would have a low surface loss. Hence, I guess the output of your network must be the boundary instead of whole regions like U-Net. Am I right?

HKervadec commented 5 years ago

Hence, I guess the output of your network must be the boundary instead of whole regions like U-Net. Am I right?

No. If you go back to the paper, Equation (5), you will see that we use the softmax probabilities (s_\theta) from the network together with the level set \phi_G. I invite you to re-read Section 2 thoroughly, as it explains in detail how we ended up with this formulation. Some papers use two branches: one for region segmentation, while the other has a regression layer to predict the distance to the boundary; both losses are then combined. With our formulation we do not need that, as the distance information is already included.
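Paraphrasing the paper (do check Equation (5) itself), the loss has the form

\mathcal{L}_B(\theta) = \int_{\Omega} \phi_G(q) \, s_\theta(q) \, \mathrm{d}q

where s_\theta(q) is the predicted softmax probability at point q and \phi_G is the signed level-set representation of the ground-truth boundary. Since \phi_G is negative inside the object (see the distance-map sketch above), confident foreground predictions in the interior actually decrease the loss; this is why the network still outputs whole regions, not boundaries.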

However, it hurts my performance.

There are several possible explanations for that:

  1. The balance between the boundary loss and your regional loss; in the paper we increase the weight of the boundary loss progressively during training, rather than using it alone.
  2. Whether or not to supervise the background class, as discussed above.
  3. The degree of imbalance of your problem; the boundary loss targets highly unbalanced settings.

All those points are still active research problems; the paper introduces the formulation and shows that it has a positive effect on highly unbalanced segmentation problems, but there is still a lot of unexplored territory.

Since this is not a code problem anymore, I think we can close this issue, and continue by e-mail if you would like to collaborate on applying this boundary loss to other settings.

John1231983 commented 5 years ago

Some papers use two branches: one for region segmentation, while the other has a regression layer to predict the distance to the boundary; both losses are then combined.

Thanks. Could you provide some paper links or titles? I guess my problem is that my object is not that highly imbalanced.