LIVIAETS / boundary-loss

Official code for "Boundary loss for highly unbalanced segmentation", runner-up for best paper award at MIDL 2019. Extended version in MedIA, volume 67, January 2021.
https://doi.org/10.1016/j.media.2020.101851
MIT License

How to apply the boundary loss to 3D images both efficiently and correctly? #29

Closed · xychenunc closed this 11 months ago

xychenunc commented 3 years ago

Hi, thanks for sharing your code. I am trying to use the boundary loss for 3D (really high-resolution) image segmentation, but I am having trouble implementing the loss function both efficiently and correctly. For 3D image segmentation, a popular approach is to train the network on image patches, and often the training samples include patches that contain only background. For these samples, a naive generalization of your implementation may give SDMs that are all 0s (using the Keras version of the loss function). To me, this does not make sense: even if these samples do not contain any foreground voxels, the SDM should not be all 0s in reality. I think it makes more sense to calculate the SDM on the entire image rather than on image patches. What do you think about this problem?
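To illustrate with a toy example (mimicking a naive per-patch SDM computation, not your actual code):

import numpy as np
from scipy.ndimage import distance_transform_edt

patch_gt = np.zeros((32, 32, 32), dtype=bool)  # a background-only training patch
if patch_gt.any():
    sdm = distance_transform_edt(~patch_gt) * ~patch_gt \
        - (distance_transform_edt(patch_gt) - 1) * patch_gt
else:
    # No foreground voxel in the patch: nothing to measure distances from,
    # so the naive per-patch SDM is all zeros
    sdm = np.zeros_like(patch_gt, dtype=np.float32)

print(sdm.max())  # 0.0, even though the true organ may be just outside the patch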

Also, I found it pretty time-consuming to calculate the SDM in 3D. How can the time efficiency be improved?

Thanks

HKervadec commented 3 years ago

The solution is to pre-compute the distance maps offline in 3D, and save them into a .npy file with the axes ordered kxyz, k being the class axis. Pay attention to the spatial resolution at this step -- the scipy function has an extra, optional sampling parameter for that.
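For instance, the offline step could look like this (the file layout and paths are just an example, and one_hot2dist is the distance function shown further down in this thread):

import numpy as np
from pathlib import Path

# Run once per training volume, before training (adapt paths to your layout)
for gt_path in sorted(Path("data/train/gt").glob("*.npy")):
    onehot = np.load(gt_path)  # (K, X, Y, Z) one-hot ground truth, k = class axis
    distmap = one_hot2dist(onehot, resolution=(1.0, 1.0, 1.0))  # voxel spacing
    np.save(Path("data/train/dist") / gt_path.name, distmap)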

Then, in the dataloader, you load the 3D distmap as-is (no extra transform besides converting it to a tensor), and subpatch it at the same time as the original image.
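For instance, a random-crop augmentation (a hypothetical helper, not something from the repo) simply has to apply the same crop to all three volumes:

import numpy as np

def random_crop3d(img, gt, dist, patch=(64, 64, 64)):
    # Pick one random corner, then apply the same crop to the image,
    # the ground truth, and the distance map (spatial axes are the last three)
    x, y, z = (np.random.randint(0, s - p + 1)
               for s, p in zip(img.shape[-3:], patch))
    sl = (..., slice(x, x + patch[0]), slice(y, y + patch[1]), slice(z, z + patch[2]))
    return img[sl], gt[sl], dist[sl]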

You can then do the usual multiplication between the distance maps and the softmaxes.
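The loss itself stays a one-liner. A 3D sketch along the lines of the repo's 2D SurfaceLoss (only the einsum axes change; take it as a sketch, not tested code):

import torch
from torch import Tensor, einsum

class SurfaceLoss3D():
    def __init__(self, idc):
        self.idc = idc  # indices of the classes to supervise

    def __call__(self, probs: Tensor, dist_maps: Tensor) -> Tensor:
        pc = probs[:, self.idc, ...].type(torch.float32)      # BKXYZ softmaxes
        dc = dist_maps[:, self.idc, ...].type(torch.float32)  # BKXYZ distance maps
        return einsum("bkxyz,bkxyz->bkxyz", pc, dc).mean()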

This is what I implemented for the extension, but I didn't have time to put it in the repo yet. I will do so soon, and then point to the exact code parts doing that, but this should already give you a rough idea of how to proceed. Let me know if you need other details in the meantime.

Hoel

xychenunc commented 3 years ago

Yes, this solves part of the problem, but for training data obtained via data augmentation, it seems there is no better choice than computing the SDM on the fly. In that case, training the network efficiently can be a big issue.

BTW, is the segmentation performance sensitive to the value of the loss weight for the boundary loss?

Thanks

HKervadec commented 3 years ago

With respect to the data augmentation

there is no better choice other than computing the SDM on the fly

When you say "on the fly", do you mean computing the distance map inside the loss function?

The way I see it, the pre-computed distance map can be augmented as well, just like we perform the augmentation on the original ground truth. The overall code would look like this:

from pathlib import Path
from typing import Dict, List, Tuple

from torch import Tensor
from torch.utils.data import Dataset

class DistDataset(Dataset):
    def __init__(self, *args, **kwargs):
        ...
        self.files: List[Tuple[Path, Path, Path]]

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        img_path, gt_path, dist_path = self.files[index]

        # ... load the three files, then perform the transforms here

        aug_img, aug_gt, aug_dist = augment(img, gt, dist)
        del img, gt, dist  # Avoid returning those by accident

        return {"img": aug_img,  # CWH shape
                "gt": aug_gt,  # KWH shape
                "distmap": aug_dist}  # KWH shape

# Then in the training loop
α = 0.01
for data in train_loader:
    imgs = data["img"].to(device)  # BCWH shape
    gts = data["gt"].to(device)  # BKWH shape
    dists = data["distmap"].to(device)  # BKWH shape

    optimizer.zero_grad()

    pred_probs = softmax(net(imgs))

    dsc_loss = DiceLoss(gts, pred_probs)
    bl_loss = BoundaryLoss(dists, pred_probs)

    total_loss = dsc_loss + α * bl_loss
    total_loss.backward()
    optimizer.step()
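(Note that the shape comments are written for the 2D case; for 3D patches the exact same loop applies, with an extra depth axis on each tensor.)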

BTW, is the segmentation performance sensitive to the value of the loss weight for the boundary loss?

In our experiments, it was somewhat sensitive, but it still consistently gave some improvement even with a sub-optimal value; see Table 3 in our extension:

[Screenshot: Table 3 of "Boundary loss for highly unbalanced segmentation" (arXiv:1812.07032)]

Increasing and rebalancing the weights were not only better performance-wise, but also much simpler to tune -- to me, that is their main advantage.

xychenunc commented 3 years ago

Thanks for your response. I tried your loss function on our dataset; however, I have not seen improved performance so far. I want to know whether the simplification from the differential form to the integral form holds in 3D, as in the 2D example you show in your paper. The reason I ask is that I think the simplification may not hold for a 3D boundary and a 3D surface. If it still holds, could you please clarify, and send me some reference articles that show it? Thanks!

xychenunc commented 3 years ago

Also, what does ‘rebalance’ mean, and how do I implement it? Thanks

HKervadec commented 3 years ago

Thanks for your response. I tried your loss function on our dataset; however, I have not seen improved performance so far. I want to know whether the simplification from the differential form to the integral form holds in 3D, as in the 2D example you show in your paper. The reason I ask is that I think the simplification may not hold for a 3D boundary and a 3D surface. If it still holds, could you please clarify, and send me some reference articles that show it? Thanks!

Yes, the result still holds, though in 3D you need to take into account the spatial resolution of each axis, as it might differ. The updated distance-computation function now looks like this:

from typing import Optional, Tuple

import numpy as np
import torch
from scipy.ndimage import distance_transform_edt as eucl_distance

# one_hot() is the sanity-check helper from the repo's utils.py

def one_hot2dist(seg: np.ndarray, resolution: Optional[Tuple[float, float, float]] = None,
                 dtype=None) -> np.ndarray:
    assert one_hot(torch.tensor(seg), axis=0)
    K: int = len(seg)

    res = np.zeros_like(seg, dtype=dtype)
    for k in range(K):
        posmask = seg[k].astype(bool)

        if posmask.any():
            negmask = ~posmask
            res[k] = eucl_distance(negmask, sampling=resolution) * negmask \
                - (eucl_distance(posmask, sampling=resolution) - 1) * posmask
        # The idea is to leave the negative classes blank:
        # since this is one-hot encoded, another class will supervise that pixel

    return res

resolution = None corresponds to sampling = (1, 1, 1).

Another thing to take into account: if the spacing between slices becomes too big (like 1 cm on the z axis while it is 1 mm on the x and y axes), then the 3D distance might not make much sense. It will depend on your application.
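For instance, with those spacings (in mm, ordered to match the array's spatial axes), the call would be:

# 1 mm in-plane (x, y), 10 mm between slices (z)
distmap = one_hot2dist(onehot_gt, resolution=(1.0, 1.0, 10.0))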

Also, what does ‘rebalance’ mean, and how do I implement it? Thanks

Rebalancing corresponds to starting with a high weight on the DSC loss and a smaller one on the boundary loss, then slowly shifting them:

α = 0.01

for e in range(epochs):
    for data in train_loader:
        ...

        total_loss = (1 - α) * dsc_loss + α * bl_loss
        total_loss.backward()
        optimizer.step()

    α = min(α + 0.01, 0.99)  # increase α a bit every epoch, capped at 0.99

xychenunc commented 3 years ago

I tried to understand the mathematics in your paper. It is interesting to see the beautiful connection between Eq. 2 and Eq. 3. However, I found it difficult to understand the derivation connecting the two. Specifically, in the paper, you mention that the two can be connected using the following:

[Screenshot: the derivation from the paper connecting Eq. 2 and Eq. 3]

To me, it is not obvious why the first two are equivalent, because dD_G/dq is not a constant: the second term after the minus sign in [the screenshotted expression] is also related to q and, I think, cannot be easily formulated. Could you please explain more on this?

HKervadec commented 11 months ago

That one fell through the cracks (sorry about that); please feel free to re-open or reply if it is still relevant.