MIC-DKFZ / Skeleton-Recall

Skeleton Recall Loss for Connectivity Conserving and Resource Efficient Segmentation of Thin Tubular Structures
Apache License 2.0

`SkeletonTransform` may extract incorrect skeleton lines for multi-class segmentation? #4

Open Yuxiang1990 opened 2 months ago

Yuxiang1990 commented 2 months ago

Hi, for multi-class segmentation, extracting the skeleton separately for each label may give a different result from extracting the skeleton of the binary mask and then multiplying it by the label mask.

```python
import numpy as np
import torch
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform  # assumed import path
from skimage.morphology import dilation, skeletonize


class SkeletonTransform(BasicTransform):
    def __init__(self, do_tube: bool = True, num_classes: int = 1):
        """
        Calculates the skeleton of the segmentation (plus an optional 2 px tube around it)
        and adds it to the dict with the key "skel"
        """
        super().__init__()
        self.do_tube = do_tube
        self.num_classes = num_classes  # needed for compatibility with 3D data
        assert self.num_classes >= 1

    def apply(self, data_dict, **params):
        # Proposed variant: skeletonize each label separately
        seg_all = data_dict['segmentation'].numpy()
        # Add tubed skeleton GT
        seg_all_skel = np.zeros_like(seg_all, dtype=np.int16)

        for labelid in range(1, self.num_classes + 1):
            mask = seg_all[0] == labelid
            if not mask.any():
                continue  # label absent in this patch
            skel = skeletonize(mask)
            skel = (skel > 0).astype(np.int16)
            if self.do_tube:
                skel = dilation(skel)  # a single dilation, whereas apply_old dilates twice
            seg_all_skel[0][skel > 0] = labelid

        data_dict["skel"] = torch.from_numpy(seg_all_skel)
        return data_dict

    def apply_old(self, data_dict, **params):
        # Binarized variant: skeletonize the union of all labels, then copy the labels back onto it
        seg_all = data_dict['segmentation'].numpy()
        # Add tubed skeleton GT
        bin_seg = (seg_all > 0)
        seg_all_skel = np.zeros_like(bin_seg, dtype=np.int16)

        if bin_seg[0].any():
            skel = skeletonize(bin_seg[0])
            skel = (skel > 0).astype(np.int16)
            if self.do_tube:
                skel = dilation(dilation(skel))
            # Restore class labels; tube pixels that fall outside the foreground become 0
            skel *= seg_all[0].astype(np.int16)
            seg_all_skel[0] = skel

        data_dict["skel"] = torch.from_numpy(seg_all_skel)
        return data_dict
```

ykirchhoff commented 2 months ago

Hi @Yuxiang1990,

you are correct that skeletonizing each label individually gives slightly different results from the binarized skeletonization we use for multi-class problems. However, the binarized version is actually what we usually want: this way, the skeletons of different classes stay connected wherever the original segmentations were connected. Think of vessels, where you might be interested in blood flow but have different classes within your vessel tree.

Best, Yannick
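To make the connectivity argument concrete, here is a minimal 2D sketch on a toy label map where two classes touch along a thick bar. It is only an illustration under stated assumptions: it uses a hand-written Zhang-Suen thinning as a stand-in for `skimage.morphology.skeletonize` (whose exact output may differ in detail), skips the tube dilation (as with `do_tube=False`), and counts 8-connected components of the skeleton while ignoring class labels. The helper names `zhang_suen_thin`, `skeleton_binary`, `skeleton_per_label` and `count_components` are hypothetical and not part of the repository.

```python
from collections import deque

import numpy as np


def zhang_suen_thin(mask):
    """Zhang-Suen thinning of a 2D boolean mask (stand-in for skimage's skeletonize)."""
    img = np.pad(mask.astype(np.uint8), 1)  # zero border so every pixel has 8 neighbours
    changed = True
    while changed:
        changed = False
        for step in (0, 1):
            to_delete = []
            for r, c in zip(*np.nonzero(img)):
                # Neighbours P2..P9, clockwise starting at north
                p = [img[r - 1, c], img[r - 1, c + 1], img[r, c + 1], img[r + 1, c + 1],
                     img[r + 1, c], img[r + 1, c - 1], img[r, c - 1], img[r - 1, c - 1]]
                b = int(sum(p))  # number of foreground neighbours
                a = sum(p[i] == 0 and p[(i + 1) % 8] == 1 for i in range(8))  # 0->1 transitions
                p2, p4, p6, p8 = p[0], p[2], p[4], p[6]
                if step == 0:
                    cond = p2 * p4 * p6 == 0 and p4 * p6 * p8 == 0
                else:
                    cond = p2 * p4 * p8 == 0 and p2 * p6 * p8 == 0
                if 2 <= b <= 6 and a == 1 and cond:
                    to_delete.append((r, c))
            for r, c in to_delete:  # delete in parallel after scanning the whole image
                img[r, c] = 0
            changed = changed or bool(to_delete)
    return img[1:-1, 1:-1].astype(bool)


def skeleton_binary(seg):
    # apply_old-style: skeletonize the union of all labels, then copy the labels onto it
    return np.where(zhang_suen_thin(seg > 0), seg, 0)


def skeleton_per_label(seg, num_classes):
    # apply-style: skeletonize every label on its own
    out = np.zeros_like(seg)
    for label in range(1, num_classes + 1):
        mask = seg == label
        if mask.any():
            out[zhang_suen_thin(mask)] = label
    return out


def count_components(mask):
    """Number of 8-connected foreground components (class labels ignored)."""
    seen = np.zeros_like(mask, dtype=bool)
    count = 0
    for start in zip(*np.nonzero(mask)):
        if seen[start]:
            continue
        count += 1
        seen[start] = True
        queue = deque([start])
        while queue:
            r, c = queue.popleft()
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    nr, nc = r + dr, c + dc
                    if (0 <= nr < mask.shape[0] and 0 <= nc < mask.shape[1]
                            and mask[nr, nc] and not seen[nr, nc]):
                        seen[nr, nc] = True
                        queue.append((nr, nc))
    return count


# Toy "vessel": a 3-pixel-thick bar, class 1 on the left, class 2 on the right
seg = np.zeros((5, 13), dtype=np.int16)
seg[1:4, 1:7] = 1
seg[1:4, 7:12] = 2

binary = skeleton_binary(seg)
per_label = skeleton_per_label(seg, num_classes=2)
print(count_components(binary > 0), count_components(per_label > 0))  # -> 1 2
```

On this toy input the binarized skeleton is a single line whose left part carries label 1 and whose right part carries label 2. Skeletonizing each label separately shortens both pieces at the class boundary and leaves a gap, so the skeleton splits into two components.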