MIC-DKFZ / Skeleton-Recall

Skeleton Recall Loss for Connectivity Conserving and Resource Efficient Segmentation of Thin Tubular Structures
Apache License 2.0

[ECCV 2024] Skeleton Recall Loss for Connectivity Conserving and Resource Efficient Segmentation of Thin Tubular Structures 🩻

Overview

Welcome to the repository for the paper "Skeleton Recall Loss for Connectivity Conserving and Resource Efficient Segmentation of Thin Tubular Structures"! This repository provides the implementation of Skeleton Recall Loss, integrated into the popular nnUNet framework.

Read the paper: arXiv

News/Updates:

Introduction

Accurately segmenting thin tubular structures, such as vessels, nerves, roads, or cracks, is a crucial task in computer vision. Traditional deep learning-based segmentation approaches often struggle to preserve the connectivity of these structures. This paper introduces Skeleton Recall Loss, a novel loss function designed to enhance connectivity conservation in thin tubular structure segmentation without incurring massive computational overheads.

Key Features

Methodology

The Skeleton Recall Loss operates by performing a tubed skeletonization on the ground truth segmentation and then computing a soft recall loss against the predicted segmentation output. This circumvents the costly calculation of a differentiable skeleton.

Tubed Skeletonization

  1. Binarization: Convert the ground truth segmentation mask to binary form.
  2. Skeleton Extraction: Compute the skeleton using efficient methods for 2D and 3D inputs.
  3. Tubular Dilation: Enlarge the skeleton using a dilation process to create a tubed skeleton.
  4. Class Assignment: For multi-class problems, assign parts of the skeleton to their respective classes.

In this implementation, the tubed skeletonization is performed during data loading; see the code.
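The four steps above can be sketched roughly as follows. This is a minimal illustration, not the repository's data-loading code: the function name `tubed_skeleton`, the `dilation_iters` parameter and its value, and the choice of `scipy.ndimage.binary_dilation` for the dilation step are assumptions made for this example. It assumes an integer label map with 0 as background.

```python
import numpy as np
from scipy.ndimage import binary_dilation
from skimage.morphology import skeletonize


def tubed_skeleton(seg: np.ndarray, dilation_iters: int = 2) -> np.ndarray:
    """Return a tubed, class-labelled skeleton for a 2D or 3D label map.

    `tubed_skeleton` and `dilation_iters` are hypothetical names used
    only for this sketch.
    """
    # 1. Binarization: every foreground class becomes True.
    binary = seg > 0
    if not binary.any():
        # Nothing to skeletonize in an empty mask.
        return np.zeros(seg.shape, dtype=np.int16)

    # 2. Skeleton extraction: scikit-image accepts 2D and 3D inputs.
    #    Comparing with 0 normalizes the result to a boolean array.
    skel = skeletonize(binary) > 0

    # 3. Tubular dilation: thicken the one-voxel-wide skeleton.
    tube = binary_dilation(skel, iterations=dilation_iters)

    # 4. Class assignment: multiplying by the original labels gives each
    #    tube voxel the class of the structure it lies in and removes
    #    voxels that were dilated into the background.
    return tube.astype(np.int16) * seg.astype(np.int16)
```

Because this runs once per training sample on the ground truth only, no gradient has to flow through the skeletonization.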

Soft Recall Loss

Full Loss calculation:

\mathcal{L} = \mathcal{L}_{Dice} + \mathcal{L}_{CE} + w \cdot \mathcal{L}_{SkelRecall}

You can change the weight w of the additional Skeleton Recall Loss term by modifying the value of self.weight_srec in the nnUNetTrainerSkeletonRecall class.
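To make the loss combination concrete, here is a small NumPy sketch of a soft recall term and the weighted sum. It is an illustration under stated assumptions rather than the repository's loss class: the function names, the `eps` smoothing term, the negative-recall sign convention, and the default weight of 1.0 are all choices made for this example. Inputs are assumed to be per-class probabilities and a one-hot tubed skeleton of shape `(C, *spatial)` with the background channel already removed.

```python
import numpy as np


def soft_skeleton_recall_loss(probs, skel_onehot, eps=1e-8):
    """Soft recall of the predicted probabilities on the tubed skeleton.

    Hypothetical helper for illustration only.
    """
    # Sum over all spatial axes, keeping one value per class.
    spatial_axes = tuple(range(1, probs.ndim))
    # Probability mass the prediction places on skeleton voxels.
    overlap = (probs * skel_onehot).sum(axis=spatial_axes)
    # Total number of skeleton voxels per class.
    total = skel_onehot.sum(axis=spatial_axes)
    recall = (overlap + eps) / (total + eps)
    # Higher recall should lower the loss, hence the negative sign.
    return float(-recall.mean())


def total_loss(dice_loss, ce_loss, skel_recall_loss, weight_srec=1.0):
    """L = L_Dice + L_CE + w * L_SkelRecall, with w passed as weight_srec."""
    return dice_loss + ce_loss + weight_srec * skel_recall_loss
```

Only the prediction appears in a differentiable position here; the skeleton is a fixed target, which is why no differentiable skeletonization is required.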

Experimental Setup

The method is validated on several public datasets featuring thin structures, including:

Usage

Installation

Check out the official nnUNet installation instructions

TL;DR

Clone the repository and install the required dependencies:

git clone https://github.com/MIC-DKFZ/skeleton-recall.git
cd skeleton-recall
pip install -e .

nnU-Net needs to know where you intend to save raw data, preprocessed data and trained models. For this you need to set a few environment variables. Please follow the instructions here.
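nnU-Net v2 reads these locations from the environment variables `nnUNet_raw`, `nnUNet_preprocessed` and `nnUNet_results`. A quick way to check that they are set before starting a run is sketched below; the helper name `missing_nnunet_vars` is made up for this example and is not part of nnU-Net.

```python
import os

# Environment variables that nnU-Net v2 expects to be defined.
REQUIRED_VARS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")


def missing_nnunet_vars(env=os.environ):
    """Return the names of required variables that are unset or empty.

    Hypothetical helper for illustration only.
    """
    return [name for name in REQUIRED_VARS if not env.get(name)]


if __name__ == "__main__":
    missing = missing_nnunet_vars()
    if missing:
        print("Please set:", ", ".join(missing))
    else:
        print("All nnU-Net paths are configured.")
```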

Integration into existing nnUNet installation

For now, if you'd like to incorporate Skeleton Recall Loss into your existing nnUNetv2 installation, you first need to copy the nnUNetTrainerSkeletonRecall class. You also have to integrate the skeletonization process during data loading, which you can find here, as well as the custom loss function here and the compound loss combination here. Integration into the official nnUNet repo is currently under discussion.

Training

To train a model using Skeleton Recall Loss with nnUNet, run:

for 2D:

nnUNetv2_train DATASET_NAME_OR_ID 2d FOLD -tr nnUNetTrainerSkeletonRecall

for 3D:

nnUNetv2_train DATASET_NAME_OR_ID 3d_fullres FOLD -tr nnUNetTrainerSkeletonRecall
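nnU-Net's default cross-validation uses folds 0 to 4, so the commands above are usually run once per fold. The sketch below only assembles those command lines; the helper name `build_train_commands` is an assumption for this example, and the dataset placeholder is left as in the commands above.

```python
import shlex


def build_train_commands(dataset, configuration="3d_fullres", folds=range(5),
                         trainer="nnUNetTrainerSkeletonRecall"):
    """Build one nnUNetv2_train argument list per fold.

    Hypothetical helper for illustration only.
    """
    return [
        ["nnUNetv2_train", str(dataset), configuration, str(fold), "-tr", trainer]
        for fold in folds
    ]


if __name__ == "__main__":
    # Print the commands; replace the placeholder with your dataset.
    for cmd in build_train_commands("DATASET_NAME_OR_ID"):
        print(shlex.join(cmd))
```

Each argument list can be passed to `subprocess.run(cmd, check=True)` to launch the folds one after another.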

Citation

If you use this code in your research, please cite our paper:

@article{kirchhoff2024skeleton,
  title={Skeleton Recall Loss for Connectivity Conserving and Resource Efficient Segmentation of Thin Tubular Structures},
  author={Kirchhoff, Yannick and Rokuss, Maximilian and Roy, Saikat and others},
  journal={European Conference on Computer Vision},
  year={2024}
}

Happy coding! 🚀

Acknowledgements

nnU-Net is developed and maintained by the Applied Computer Vision Lab (ACVL) of Helmholtz Imaging and the Division of Medical Image Computing at the German Cancer Research Center (DKFZ).