Amshaker / unetr_plus_plus

[IEEE TMI-2024] UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation
Apache License 2.0

UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation

Abdelrahman Shaker*¹, Muhammad Maaz¹, Hanoona Rasheed¹, Salman Khan¹, Ming-Hsuan Yang²,³ and Fahad Shahbaz Khan¹,⁴

¹Mohamed Bin Zayed University of Artificial Intelligence, ²University of California Merced, ³Google Research, ⁴Linköping University





Abstract: Owing to the success of transformer models, recent works study their applicability in 3D medical segmentation tasks. Within the transformer models, the self-attention mechanism is one of the main building blocks that strives to capture long-range dependencies. However, the self-attention operation has quadratic complexity, which proves to be a computational bottleneck, especially in volumetric medical imaging, where the inputs are 3D with numerous slices. In this paper, we propose a 3D medical image segmentation approach, named UNETR++, that offers both high-quality segmentation masks and efficiency in terms of parameters, compute cost, and inference speed. The core of our design is the introduction of a novel efficient paired attention (EPA) block that efficiently learns spatial and channel-wise discriminative features using a pair of inter-dependent branches based on spatial and channel attention. Our spatial attention formulation is efficient, with linear complexity with respect to the input sequence length. To enable communication between the spatial and channel-focused branches, we share the weights of the query and key mapping functions, which provides a complementary benefit (paired attention) while also reducing the overall network parameters. Our extensive evaluations on five benchmarks, Synapse, BTCV, ACDC, BRaTs, and Decathlon-Lung, reveal the effectiveness of our contributions in terms of both efficiency and accuracy. On Synapse, our UNETR++ sets a new state-of-the-art with a Dice Score of 87.2%, while being significantly more efficient, reducing both parameters and FLOPs by over 71% compared to the best method in the literature.
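To make the complexity argument concrete, the short calculation below counts the entries of a single attention map for a volumetric input. The 4×4×4 patch size, the two volume sizes, and the projected length p = 64 are illustrative assumptions rather than the paper's settings; the "projected" count assumes keys are mapped to a fixed number of tokens p, which is one common way to obtain attention that is linear in the sequence length.

```python
def attention_entries(depth, height, width, patch=4, proj_len=64):
    """Entries in one attention map for a volume split into patch x patch x patch tokens."""
    n = (depth // patch) * (height // patch) * (width // patch)  # sequence length
    full = n * n              # standard self-attention: n x n map (quadratic in n)
    projected = n * proj_len  # keys projected to proj_len tokens: n x p map (linear in n)
    return n, full, projected


for size in [(64, 128, 128), (128, 128, 128)]:
    n, full, projected = attention_entries(*size)
    print(f"{size}: n={n}, full={full:,}, projected={projected:,}, ratio={full // projected}x")
# (64, 128, 128): n=16384, full=268,435,456, projected=1,048,576, ratio=256x
# (128, 128, 128): n=32768, full=1,073,741,824, projected=2,097,152, ratio=512x
```

Doubling the number of slices doubles the projected count but quadruples the full one, which is the gap that grows with volumetric inputs.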


Architecture overview of UNETR++

Overview of our UNETR++ approach with a hierarchical encoder-decoder structure. The 3D patches are fed to the encoder, whose outputs are connected to the decoder via skip connections, followed by convolutional blocks that produce the final segmentation mask. The focus of our design is the introduction of an efficient paired-attention (EPA) block. Each EPA block performs two tasks using parallel attention modules that share the query and key projections but use different value layers, so that it efficiently learns enriched spatial-channel feature representations. As illustrated in the EPA block diagram (on the right), the first (top) attention module aggregates the spatial features through a weighted sum of the projected features in a linear manner to compute the spatial attention maps, while the second (bottom) attention module emphasizes the dependencies between channels and computes the channel attention maps. Finally, the outputs of the two attention modules are fused and passed to convolutional blocks to enhance the feature representation, leading to better segmentation masks.
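The paired-attention idea can be sketched in NumPy as follows. This is a single-head, unbatched illustration written for this README, not the repository's implementation: the class name `PairedAttentionSketch`, the scaling factors, the sequence-axis projection `e` that makes the spatial branch linear in the number of tokens, and the concatenation-based fusion with a residual connection are our assumptions. Normalization layers, multiple heads, and the convolutional blocks are omitted.

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


class PairedAttentionSketch:
    """EPA-style block: shared Q/K, separate values for the spatial and channel branches."""

    def __init__(self, n_tokens, channels, proj_len, seed=0):
        assert channels % 2 == 0, "each branch is projected to channels // 2"
        rng = np.random.default_rng(seed)
        c = channels
        self.w_q = rng.standard_normal((c, c)) / np.sqrt(c)  # shared query projection
        self.w_k = rng.standard_normal((c, c)) / np.sqrt(c)  # shared key projection
        self.w_v_spatial = rng.standard_normal((c, c)) / np.sqrt(c)
        self.w_v_channel = rng.standard_normal((c, c)) / np.sqrt(c)
        # Assumed projection of the sequence axis from n_tokens down to proj_len
        self.e = rng.standard_normal((proj_len, n_tokens)) / np.sqrt(n_tokens)
        self.w_out_spatial = rng.standard_normal((c, c // 2)) / np.sqrt(c)
        self.w_out_channel = rng.standard_normal((c, c // 2)) / np.sqrt(c)

    def __call__(self, x):
        n, c = x.shape
        q, k = x @ self.w_q, x @ self.w_k           # (n, c) each, shared by both branches
        # Spatial branch: n queries attend to proj_len projected keys -> (n, p) map
        k_p = self.e @ k                             # (p, c)
        v_p = self.e @ (x @ self.w_v_spatial)        # (p, c)
        attn_s = softmax(q @ k_p.T / np.sqrt(c))     # (n, p), grows linearly with n
        out_s = attn_s @ v_p                         # (n, c)
        # Channel branch: (c, c) map whose size does not depend on n
        attn_c = softmax(q.T @ k / np.sqrt(n))       # (c, c)
        out_c = (x @ self.w_v_channel) @ attn_c.T    # (n, c)
        # Fuse: project each branch to c // 2 channels, concatenate, add the input back
        fused = np.concatenate(
            [out_s @ self.w_out_spatial, out_c @ self.w_out_channel], axis=1
        )
        return x + fused


x = np.random.default_rng(1).standard_normal((512, 32))  # 512 tokens, 32 channels
block = PairedAttentionSketch(n_tokens=512, channels=32, proj_len=64)
print(block(x).shape)  # (512, 32)
```

Because `e` maps a fixed number of tokens to `proj_len`, a block built this way is tied to one input resolution; that is the usual trade-off of projecting the sequence axis to get linear cost.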


Results

Synapse Dataset

State-of-the-art comparison on the abdominal multi-organ Synapse dataset. We report both segmentation performance (DSC: Dice similarity coefficient; HD95: 95th-percentile Hausdorff distance) and model complexity (parameters and FLOPs). Our proposed UNETR++ achieves favorable segmentation performance against existing methods while considerably reducing model complexity. Abbreviations stand for: Spl: spleen, RKid: right kidney, LKid: left kidney, Gal: gallbladder, Liv: liver, Sto: stomach, Aor: aorta, Pan: pancreas. Best results are in bold.
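The two segmentation metrics in this table can be illustrated for a single binary organ mask with the NumPy sketch below. It is not the authors' evaluation code: the function names are ours, two empty masks are scored as a Dice of 1.0 by convention, surface voxels are taken to be foreground voxels with at least one 6-connected background neighbour, and HD95 is the 95th percentile of both directed surface-distance sets pooled together, roughly following MedPy's `hd95`. The brute-force pairwise distance matrix is only practical for small masks.

```python
import numpy as np


def dice(pred, gt):
    """Dice similarity coefficient of two binary masks."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return 1.0  # both masks empty: treated as perfect agreement (assumed convention)
    return 2.0 * np.logical_and(pred, gt).sum() / denom


def surface_points(mask):
    """Voxel coordinates of foreground voxels that touch the background (6-connectivity)."""
    m = np.pad(mask.astype(bool), 1)  # pad with background so border voxels count as surface
    core = m[1:-1, 1:-1, 1:-1]
    interior = core.copy()
    for axis in range(3):
        for shift in (1, -1):  # compare with both neighbours along each axis
            interior &= np.roll(m, shift, axis=axis)[1:-1, 1:-1, 1:-1]
    return np.argwhere(core & ~interior)


def hd95(pred, gt, spacing=(1.0, 1.0, 1.0)):
    """95th-percentile symmetric surface distance, in the units of `spacing`."""
    a = surface_points(pred) * np.asarray(spacing)
    b = surface_points(gt) * np.asarray(spacing)
    if len(a) == 0 or len(b) == 0:
        raise ValueError("HD95 is undefined when a mask is empty")
    d = np.sqrt(((a[:, None, :] - b[None, :, :]) ** 2).sum(-1))  # (|a|, |b|) distances
    return float(np.percentile(np.concatenate([d.min(axis=1), d.min(axis=0)]), 95))


a = np.zeros((10, 10, 10), dtype=bool)
a[2:6, 2:6, 2:6] = True
b = np.zeros_like(a)
b[4:8, 2:6, 2:6] = True  # same cube shifted by 2 voxels along the first axis
print(dice(a, b), hd95(a, b))
```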

Synapse Results


Qualitative Comparison

Synapse Dataset

Qualitative comparison on the multi-organ segmentation task. Here, we compare our UNETR++ with existing methods: UNETR, Swin UNETR, and nnFormer. The different abdominal organs are shown in the legend below the examples. Existing methods struggle to correctly segment different organs (marked with red dashed boxes). Our UNETR++ achieves promising segmentation performance by accurately segmenting the organs.

ACDC Dataset

Qualitative comparison on the ACDC dataset. We compare our UNETR++ with existing methods: UNETR and nnFormer. The existing methods struggle to correctly segment the different cardiac structures (marked with red dashed boxes), whereas our UNETR++ achieves favorable segmentation performance by segmenting them accurately.


Installation

The code has been tested with PyTorch 1.11.0 and CUDA 11.3. After cloning the repository, follow the steps below to install it:

  1. Create and activate conda environment
    conda create --name unetr_pp python=3.8
    conda activate unetr_pp
  2. Install PyTorch and torchvision
    pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
  3. Install other dependencies
    pip install -r requirements.txt
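After installation, a quick check like the one below can confirm that the active environment sees the expected PyTorch build and whether a CUDA device is available. The function name `check_environment` is hypothetical, and the version strings it compares against simply mirror the tested versions listed above.

```python
import importlib.util
import sys


def check_environment(expected_torch="1.11.0", expected_cuda="11.3"):
    """Report Python/PyTorch/CUDA versions as a dict instead of raising."""
    report = {"python": sys.version.split()[0], "torch": None}
    if importlib.util.find_spec("torch") is None:
        return report  # PyTorch is not installed in this environment
    import torch

    report["torch"] = torch.__version__      # e.g. "1.11.0+cu113" for the wheel above
    report["cuda"] = torch.version.cuda      # None for CPU-only builds
    report["cuda_available"] = torch.cuda.is_available()
    report["matches_tested_setup"] = (
        torch.__version__.startswith(expected_torch) and torch.version.cuda == expected_cuda
    )
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```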

Dataset

We follow the same dataset preprocessing as in nnFormer. We conducted extensive experiments on five benchmarks: Synapse, BTCV, ACDC, BRaTs, and Decathlon-Lung.

The dataset folders for Synapse should be organized as follows:

./DATASET_Synapse/
  ├── unetr_pp_raw/
      ├── unetr_pp_raw_data/
          ├── Task02_Synapse/
              ├── imagesTr/
              ├── imagesTs/
              ├── labelsTr/
              ├── labelsTs/
              ├── dataset.json
          ├── Task002_Synapse
      ├── unetr_pp_cropped_data/
          ├── Task002_Synapse

The dataset folders for ACDC should be organized as follows:

./DATASET_Acdc/
  ├── unetr_pp_raw/
      ├── unetr_pp_raw_data/
          ├── Task01_ACDC/
              ├── imagesTr/
              ├── imagesTs/
              ├── labelsTr/
              ├── labelsTs/
              ├── dataset.json
          ├── Task001_ACDC
      ├── unetr_pp_cropped_data/
          ├── Task001_ACDC

The dataset folders for Decathlon-Lung should be organized as follows:

./DATASET_Lungs/
  ├── unetr_pp_raw/
      ├── unetr_pp_raw_data/
          ├── Task06_Lung/
              ├── imagesTr/
              ├── imagesTs/
              ├── labelsTr/
              ├── labelsTs/
              ├── dataset.json
          ├── Task006_Lung
      ├── unetr_pp_cropped_data/
          ├── Task006_Lung

The dataset folders for BRaTs should be organized as follows:

./DATASET_Tumor/
  ├── unetr_pp_raw/
      ├── unetr_pp_raw_data/
          ├── Task03_tumor/
              ├── imagesTr/
              ├── imagesTs/
              ├── labelsTr/
              ├── labelsTs/
              ├── dataset.json
          ├── Task003_tumor
      ├── unetr_pp_cropped_data/
          ├── Task003_tumor
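Since the four trees differ only in their root and task folder names, a small check like the one below can catch misplaced files before training. It is a hypothetical helper, not part of the repository; it assumes the DATASET_* folders sit directly under the project directory (as the ./ prefix suggests) and only checks that the listed paths exist, not what they contain.

```python
from pathlib import Path

# Root folder, raw task folder, and task folder for each layout shown above.
LAYOUTS = {
    "synapse": ("DATASET_Synapse", "Task02_Synapse", "Task002_Synapse"),
    "acdc": ("DATASET_Acdc", "Task01_ACDC", "Task001_ACDC"),
    "lung": ("DATASET_Lungs", "Task06_Lung", "Task006_Lung"),
    "tumor": ("DATASET_Tumor", "Task03_tumor", "Task003_tumor"),
}
RAW_CONTENTS = ("imagesTr", "imagesTs", "labelsTr", "labelsTs", "dataset.json")


def expected_paths(project_dir, dataset):
    """List every file/folder from the tree above for one dataset."""
    root, raw_task, task = LAYOUTS[dataset]
    base = Path(project_dir) / root / "unetr_pp_raw"
    paths = [base / "unetr_pp_raw_data" / raw_task / item for item in RAW_CONTENTS]
    paths.append(base / "unetr_pp_raw_data" / task)
    paths.append(base / "unetr_pp_cropped_data" / task)
    return paths


def missing_paths(project_dir, dataset):
    """Return the expected paths that do not exist under project_dir."""
    return [p for p in expected_paths(project_dir, dataset) if not p.exists()]


if __name__ == "__main__":
    for name in LAYOUTS:
        missing = missing_paths(".", name)
        status = "OK" if not missing else f"{len(missing)} missing, e.g. {missing[0]}"
        print(f"{name}: {status}")
```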

Please refer to "Setting up the datasets" in the nnFormer repository for more details. Alternatively, you can download the preprocessed datasets for Synapse, ACDC, Decathlon-Lung, and BRaTs and extract them under the project directory.

Training

The following scripts can be used for training our UNETR++ model on the datasets:

bash training_scripts/run_training_synapse.sh
bash training_scripts/run_training_acdc.sh
bash training_scripts/run_training_lung.sh
bash training_scripts/run_training_tumor.sh

Evaluation

To reproduce the results of UNETR++:

1- Download the Synapse weights and place model_final_checkpoint.model in the following directory:

unetr_pp/evaluation/unetr_pp_synapse_checkpoint/unetr_pp/3d_fullres/Task002_Synapse/unetr_pp_trainer_synapse__unetr_pp_Plansv2.1/fold_0/

Then, run

bash evaluation_scripts/run_evaluation_synapse.sh

2- Download the ACDC weights and place model_final_checkpoint.model in the following directory:

unetr_pp/evaluation/unetr_pp_acdc_checkpoint/unetr_pp/3d_fullres/Task001_ACDC/unetr_pp_trainer_acdc__unetr_pp_Plansv2.1/fold_0/

Then, run

bash evaluation_scripts/run_evaluation_acdc.sh

3- Download the Decathlon-Lung weights and place model_final_checkpoint.model in the following directory:

unetr_pp/evaluation/unetr_pp_lung_checkpoint/unetr_pp/3d_fullres/Task006_Lung/unetr_pp_trainer_lung__unetr_pp_Plansv2.1/fold_0/

Then, run

bash evaluation_scripts/run_evaluation_lung.sh

4- Download the BRaTs weights and place model_final_checkpoint.model in the following directory:

unetr_pp/evaluation/unetr_pp_lung_checkpoint/unetr_pp/3d_fullres/Task003_tumor/unetr_pp_trainer_tumor__unetr_pp_Plansv2.1/fold_0/

Then, run

bash evaluation_scripts/run_evaluation_tumor.sh
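Because the four destination paths differ in only a few components, the hypothetical helper below rebuilds them from steps 1-4 and copies a downloaded checkpoint into place. The names `checkpoint_dir` and `place_checkpoint` are ours, not part of the repository, and the helper assumes it is run from the repository root. Step 4 lists the `unetr_pp_lung_checkpoint` folder for BRaTs; the sketch reproduces that path as written, so compare it with the path used in evaluation_scripts/run_evaluation_tumor.sh if the checkpoint is not found.

```python
import shutil
from pathlib import Path

# (checkpoint folder, task folder, trainer suffix), copied from steps 1-4 above.
EVAL_TARGETS = {
    "synapse": ("unetr_pp_synapse_checkpoint", "Task002_Synapse", "synapse"),
    "acdc": ("unetr_pp_acdc_checkpoint", "Task001_ACDC", "acdc"),
    "lung": ("unetr_pp_lung_checkpoint", "Task006_Lung", "lung"),
    # Kept exactly as listed in step 4; verify against run_evaluation_tumor.sh.
    "tumor": ("unetr_pp_lung_checkpoint", "Task003_tumor", "tumor"),
}
CHECKPOINT_NAME = "model_final_checkpoint.model"


def checkpoint_dir(dataset, project_dir="."):
    """Folder in which the evaluation step expects model_final_checkpoint.model."""
    folder, task, trainer = EVAL_TARGETS[dataset]
    return (Path(project_dir) / "unetr_pp" / "evaluation" / folder / "unetr_pp" / "3d_fullres"
            / task / f"unetr_pp_trainer_{trainer}__unetr_pp_Plansv2.1" / "fold_0")


def place_checkpoint(downloaded_file, dataset, project_dir="."):
    """Copy a downloaded checkpoint into the expected folder and return its new path."""
    dest_dir = checkpoint_dir(dataset, project_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)  # create fold_0/ and its parents if needed
    dest = dest_dir / CHECKPOINT_NAME
    shutil.copy2(downloaded_file, dest)
    return dest


# Example (the download location is hypothetical):
# place_checkpoint("downloads/model_final_checkpoint.model", "synapse")
# then run: bash evaluation_scripts/run_evaluation_synapse.sh
```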

Acknowledgement

This repository is built on the nnFormer repository.

Citation

If you use our work, please consider citing:

@ARTICLE{10526382,
  title={UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation}, 
  author={Shaker, Abdelrahman M. and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz},
  journal={IEEE Transactions on Medical Imaging}, 
  year={2024},
  doi={10.1109/TMI.2024.3398728}}

Contact

Should you have any questions, please open an issue in this repository or contact me at abdelrahman.youssief@mbzuai.ac.ae.