Welcome to the repository for the paper "Skeleton Recall Loss for Connectivity Conserving and Resource Efficient Segmentation of Thin Tubular Structures"! This repository provides an implementation of Skeleton Recall Loss, integrated into the popular nnU-Net framework.
Accurately segmenting thin tubular structures, such as vessels, nerves, roads, or cracks, is a crucial task in computer vision. Traditional deep learning-based segmentation approaches often struggle to preserve the connectivity of these structures. This paper introduces Skeleton Recall Loss, a novel loss function designed to enhance connectivity conservation in thin tubular structure segmentation without incurring substantial computational overhead.
The Skeleton Recall Loss operates by performing a tubed skeletonization on the ground truth segmentation and then computing a soft recall loss against the predicted segmentation output. This circumvents the costly calculation of a differentiable skeleton.
In this implementation, the tubed skeletonization is performed during data loading rather than inside the loss computation.
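To make the tubed skeletonization concrete, here is a minimal 2D sketch in NumPy. It is not the repository's data-loading code: it assumes Zhang-Suen thinning as the skeletonization routine and a single dilation with a 3x3 cross as the tube-forming step, and the function names `zhang_suen_skeleton`, `dilate_cross` and `tubed_skeleton` are hypothetical. The actual transform may differ in its skeletonization routine, tube thickness, and 3D handling.

```python
import numpy as np

# Neighbour offsets P2..P9, clockwise starting at north (Zhang-Suen convention).
_OFFSETS = [(-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1)]
# 4-connected neighbours: a 3x3 cross structuring element without its centre.
_CROSS = [(-1, 0), (1, 0), (0, -1), (0, 1)]


def _shifted(img, offsets):
    """For each (dr, dc) offset, return an array of every pixel's neighbour value (zero outside)."""
    padded = np.pad(img, 1)
    h, w = img.shape
    return [padded[1 + dr:1 + dr + h, 1 + dc:1 + dc + w] for dr, dc in offsets]


def zhang_suen_skeleton(mask: np.ndarray) -> np.ndarray:
    """Thin a 2D binary mask to a one-pixel-wide skeleton (Zhang-Suen thinning, assumed here)."""
    img = (mask > 0).astype(np.uint8)
    changed = True
    while changed:
        changed = False
        for first_subiteration in (True, False):
            nbrs = _shifted(img, _OFFSETS)
            p2, _, p4, _, p6, _, p8, _ = nbrs
            # B: number of foreground neighbours.
            b = sum(n.astype(np.int32) for n in nbrs)
            # A: number of 0 -> 1 transitions in the cyclic sequence P2..P9, P2.
            cyc = nbrs + [p2]
            a = sum(((cyc[i] == 0) & (cyc[i + 1] == 1)).astype(np.int32) for i in range(8))
            if first_subiteration:
                cond = (p2 * p4 * p6 == 0) & (p4 * p6 * p8 == 0)
            else:
                cond = (p2 * p4 * p8 == 0) & (p2 * p6 * p8 == 0)
            # Delete removable boundary pixels simultaneously.
            remove = (img == 1) & (b >= 2) & (b <= 6) & (a == 1) & cond
            if remove.any():
                img[remove] = 0
                changed = True
    return img


def dilate_cross(img: np.ndarray, iterations: int = 1) -> np.ndarray:
    """Binary dilation with a 3x3 cross structuring element, repeated `iterations` times."""
    out = img.astype(np.uint8)
    for _ in range(iterations):
        grown = out.copy()
        for nb in _shifted(out, _CROSS):
            grown |= nb
        out = grown
    return out


def tubed_skeleton(mask: np.ndarray, iterations: int = 1) -> np.ndarray:
    """Skeletonize the ground-truth mask, then thicken the skeleton into a thin tube."""
    return dilate_cross(zhang_suen_skeleton(mask), iterations)
```

Because this target depends only on the ground truth, it can be computed per sample on the data-loading side, so no differentiable skeleton of the prediction is needed.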
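The total training loss combines nnU-Net's default Dice and cross-entropy terms with the weighted Skeleton Recall term:
\mathcal{L} = \mathcal{L}_{Dice} + \mathcal{L}_{CE} + w \cdot \mathcal{L}_{SkelRecall}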
You can change the weight w of the additional Skeleton Recall Loss term by modifying the value of self.weight_srec in nnUNetTrainerSkeletonRecall.
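The loss arithmetic can be sketched in NumPy for a single binary foreground class. This is an illustration under stated assumptions, not the trainer's PyTorch code: it assumes foreground probabilities `prob`, a small smoothing constant, a `1 - recall` form for the Skeleton Recall term, and binary cross-entropy in place of nnU-Net's multi-class cross-entropy. All function names are hypothetical; the argument `w` plays the role of `self.weight_srec`.

```python
import numpy as np


def soft_skeleton_recall_loss(prob: np.ndarray, tubed_skel: np.ndarray, smooth: float = 1e-5) -> float:
    """Assumed form: 1 - soft recall of the tubed ground-truth skeleton under `prob`."""
    skel = tubed_skel.astype(np.float64)
    recall = (np.sum(prob * skel) + smooth) / (np.sum(skel) + smooth)
    return float(1.0 - recall)


def soft_dice_loss(prob: np.ndarray, gt: np.ndarray, smooth: float = 1e-5) -> float:
    """1 - soft Dice overlap between foreground probabilities and the binary ground truth."""
    gt = gt.astype(np.float64)
    dice = (2.0 * np.sum(prob * gt) + smooth) / (np.sum(prob) + np.sum(gt) + smooth)
    return float(1.0 - dice)


def binary_cross_entropy(prob: np.ndarray, gt: np.ndarray, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    p = np.clip(prob, eps, 1.0 - eps)
    gt = gt.astype(np.float64)
    return float(-np.mean(gt * np.log(p) + (1.0 - gt) * np.log(1.0 - p)))


def total_loss(prob: np.ndarray, gt: np.ndarray, tubed_skel: np.ndarray, w: float = 1.0) -> float:
    """L = L_Dice + L_CE + w * L_SkelRecall (binary simplification)."""
    return (soft_dice_loss(prob, gt)
            + binary_cross_entropy(prob, gt)
            + w * soft_skeleton_recall_loss(prob, tubed_skel))
```

Since the recall is evaluated only on skeleton pixels, this term rewards covering the tubed centerline and does not penalize false positives; those are left to the Dice and cross-entropy terms.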
The method is validated on several public datasets featuring thin structures.
For general setup, check out the official nnU-Net installation instructions.
TL;DR
Clone the repository and install the required dependencies:
git clone https://github.com/MIC-DKFZ/skeleton-recall.git
cd skeleton-recall
pip install -e .
nnU-Net needs to know where you intend to save raw data, preprocessed data and trained models. For this, you need to set a few environment variables. Please follow the corresponding instructions in the official nnU-Net documentation.
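nnU-Net v2 reads these locations from the environment variables `nnUNet_raw`, `nnUNet_preprocessed` and `nnUNet_results`. The helper below is a hypothetical sketch (`check_nnunet_paths` is not part of nnU-Net) for verifying that all three are set before launching training; treating a missing directory as an error is our own choice.

```python
import os
from pathlib import Path

# Environment variables that nnU-Net v2 uses to locate its data and results.
NNUNET_PATH_VARS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")


def check_nnunet_paths() -> dict:
    """Return {variable: Path}, raising RuntimeError if a variable is unset or not a directory."""
    paths = {}
    for var in NNUNET_PATH_VARS:
        value = os.environ.get(var)
        if not value:
            raise RuntimeError(f"Environment variable {var} is not set")
        path = Path(value)
        if not path.is_dir():
            raise RuntimeError(f"{var}={value} is not an existing directory")
        paths[var] = path
    return paths
```

In a shell, the variables are typically set with `export`, for example `export nnUNet_raw=/path/to/nnUNet_raw`.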
For now, if you'd like to incorporate Skeleton Recall Loss into your existing nnUNetv2 installation, you first need to copy the nnUNetTrainerSkeletonRecall class. You also have to integrate the skeletonization step into data loading, as well as the custom loss function and the compound loss combination, all of which are implemented in this repository. Integration into the official nnUNet repository is currently being discussed.
To train a model using Skeleton Recall Loss with nnUNet, run:
For 2D:
nnUNetv2_train DATASET_NAME_OR_ID 2d FOLD -tr nnUNetTrainerSkeletonRecall
For 3D:
nnUNetv2_train DATASET_NAME_OR_ID 3d_fullres FOLD -tr nnUNetTrainerSkeletonRecall
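nnU-Net's default cross-validation uses five folds (0 to 4). As a convenience sketch, the hypothetical helpers below assemble the `nnUNetv2_train` command shown above for each fold and run the folds sequentially with `subprocess`. They assume `nnUNetv2_train` is on your `PATH`; the dataset identifier in the usage line is made up.

```python
import subprocess
from typing import Iterable, List, Union


def build_train_cmd(dataset: Union[str, int], configuration: str, fold: int,
                    trainer: str = "nnUNetTrainerSkeletonRecall") -> List[str]:
    """Assemble: nnUNetv2_train DATASET_NAME_OR_ID CONFIGURATION FOLD -tr TRAINER."""
    return ["nnUNetv2_train", str(dataset), configuration, str(fold), "-tr", trainer]


def train_all_folds(dataset: Union[str, int], configuration: str = "3d_fullres",
                    folds: Iterable[int] = range(5)) -> None:
    """Train each fold in turn; check=True stops at the first failing run."""
    for fold in folds:
        subprocess.run(build_train_cmd(dataset, configuration, fold), check=True)
```

Usage: `train_all_folds("Dataset123_Example", "2d")` trains folds 0 to 4 of the 2D configuration one after another.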
If you use this code in your research, please cite our paper:
@inproceedings{kirchhoff2024skeleton,
title={Skeleton Recall Loss for Connectivity Conserving and Resource Efficient Segmentation of Thin Tubular Structures},
author={Kirchhoff, Yannick and Rokuss, Maximilian and Roy, Saikat and others},
booktitle={European Conference on Computer Vision},
year={2024}
}
Happy coding!
nnU-Net is developed and maintained by the Applied Computer Vision Lab (ACVL) of Helmholtz Imaging and the Division of Medical Image Computing at the German Cancer Research Center (DKFZ).