This is the official PyTorch implementation of our MICCAI 2023 paper "UniSeg: A Prompt-driven Universal Segmentation Model as well as A Strong Representation Learner". In this paper, we propose a Prompt-Driven Universal Segmentation model (UniSeg) to segment multiple organs, tumors, and vertebrae on 3D medical images with diverse modalities and domains.
self.task, self.task_class, and self.total_task_num in the UniSeg_Trainer.
CUDA 11.5
Python 3.8
PyTorch 1.11.0
CuDNN 8.3.2.44
git clone https://github.com/yeerwen/UniSeg.git
cd UniSeg
Step 1:
pip install nnunet
export nnUNet_raw_data_base="/data/userdisk0/ywye/nnUNet_raw"
export nnUNet_preprocessed="/erwen_SSD/1T/nnUNet_preprocessed"
export RESULTS_FOLDER="/data/userdisk0/ywye/nnUNet_trained_models"
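Before moving on to Step 2, it can help to confirm that the three variables above are visible to Python. The following is a minimal sketch, not part of the repository; the helper name missing_nnunet_vars is hypothetical.

```python
import os

# The three variables exported in Step 1.
REQUIRED_VARS = ("nnUNet_raw_data_base", "nnUNet_preprocessed", "RESULTS_FOLDER")


def missing_nnunet_vars(env=None):
    """Return the names of required variables that are unset or empty.

    Hypothetical helper for illustration; `env` defaults to os.environ.
    """
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


if __name__ == "__main__":
    missing = missing_nnunet_vars()
    if missing:
        print("Please export:", ", ".join(missing))
    else:
        print("All nnUNet paths are set.")
```

Run it in the same shell session in which the export lines were executed, since exported variables do not carry over to new terminals unless they are added to a shell startup file.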
Step 2:
cd Upstream
The converted datasets are expected in the $nnUNet_raw_data_base/nnUNet_raw_data/ directory.
Run python prepare_Kidney_Dataset.py to normalize the names of the volumes for the Kidney dataset.
Run python Convert_MOTS_to_nnUNet_dataset.py to pre-process the MOTS dataset.
Run python Convert_VerSe20_to_nnUNet_dataset.py to pre-process the VerSe20 dataset and generate splits_final.pkl.
Run python Convert_Prostate_to_nnUNet_dataset.py to pre-process the Prostate dataset and generate splits_final.pkl.
Run python Convert_BraTS21_to_nnUNet_dataset.py to pre-process the BraTS21 dataset and generate splits_final.pkl.
Run python Convert_AutoPET_to_nnUNet_dataset.py to pre-process the AutoPET2022 dataset and generate splits_final.pkl.
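For orientation, a splits_final.pkl file in nnUNet v1 is a pickled list with one entry per fold, each entry mapping 'train' and 'val' to case identifiers. The sketch below writes and reads such a file. The round-robin split and the case names are assumptions made for illustration; they are not how the conversion scripts above choose their splits.

```python
import pickle
import tempfile
from collections import OrderedDict
from pathlib import Path


def make_splits(case_ids, num_folds=5):
    """Build one {'train': [...], 'val': [...]} entry per fold (round-robin, illustrative)."""
    case_ids = sorted(case_ids)
    splits = []
    for fold in range(num_folds):
        val = case_ids[fold::num_folds]  # every num_folds-th case goes to validation
        val_set = set(val)
        train = [c for c in case_ids if c not in val_set]
        splits.append(OrderedDict(train=train, val=val))
    return splits


def save_splits(splits, path):
    """Pickle the list of folds to `path`."""
    with open(path, "wb") as f:
        pickle.dump(splits, f)


def load_splits(path):
    """Load a pickled list of folds from `path`."""
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    cases = [f"case_{i:03d}" for i in range(10)]  # hypothetical case identifiers
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "splits_final.pkl"
        save_splits(make_splits(cases), path)
        fold0 = load_splits(path)[0]
        print("fold 0 validation cases:", fold0["val"])
```

Loading an existing splits_final.pkl with load_splits is a quick way to check which cases land in the validation set of a given fold.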
Step 3:
Copy Upstream/nnunet to replace nnunet, which is installed by pip install nnunet (the path is usually 'anaconda3/envs/your envs/lib/python3.8/site-packages/nnunet').
Run nnUNet_plan_and_preprocess -t 91 --verify_dataset_integrity --planner3d MOTSPlanner3D.
Run nnUNet_plan_and_preprocess -t 37 --verify_dataset_integrity --planner3d VerSe20Planner3D.
Run nnUNet_plan_and_preprocess -t 20 --verify_dataset_integrity --planner3d ProstatePlanner3D.
Run nnUNet_plan_and_preprocess -t 21 --verify_dataset_integrity --planner3d BraTS21Planner3D.
Run nnUNet_plan_and_preprocess -t 11 --verify_dataset_integrity --planner3d AutoPETPlanner3D.
Move splits_final.pkl of each dataset to the directory of its pre-processed dataset, for example '***/nnUNet_preprocessed/Task091_MOTS/splits_final.pkl'. Note that, to follow DoDNet, we provide splits_final.pkl of the MOTS dataset in Upstream/MOTS_data_split/splits_final.pkl.
Run python merge_each_sub_dataet.py to form a new dataset.
A data split covering all 11 tasks is provided in Upstream/splits_final_11_tasks.pkl.
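The first item of Step 3, replacing the pip-installed nnunet with the copy from this repository, can be sketched in Python as follows. This is an illustration under stated assumptions rather than a script shipped with the repository: the helper names and the ".orig" backup suffix are hypothetical, and the backup step is an addition for safety.

```python
import importlib.util
import shutil
from pathlib import Path


def installed_package_dir(name="nnunet"):
    """Return the directory of an installed package, or None if it cannot be found."""
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        return None
    return Path(list(spec.submodule_search_locations)[0])


def replace_package(src, dst, backup_suffix=".orig"):
    """Replace directory `dst` with a copy of `src`, keeping the first `dst` as a backup."""
    src, dst = Path(src), Path(dst)
    if not src.is_dir():
        raise FileNotFoundError(src)
    backup = dst.with_name(dst.name + backup_suffix)
    if dst.exists() and not backup.exists():
        shutil.move(str(dst), str(backup))  # keep the pip-installed copy
    elif dst.exists():
        shutil.rmtree(dst)  # a backup already exists, so drop the current copy
    shutil.copytree(src, dst)
    return backup


if __name__ == "__main__":
    target = installed_package_dir("nnunet")
    if target is None:
        print("nnunet is not installed in this environment")
    else:
        # To apply, call replace_package("Upstream/nnunet", target) from the repo root.
        print("installed nnunet found at:", target)
```

Locating the package through importlib avoids guessing the site-packages path, which differs between environments.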
Move Upstream/run_ssl.sh and Upstream/UniSeg_Metrics_test.py to "***/nnUNet_trained_models/".
cd ***/nnUNet_trained_models/
Run sh run_ssl.sh for training (GPU memory cost: ~10 GB; time cost: ~210 s per epoch).
cd Downstream
Run python Convert_BTCV_to_nnUNet_dataset.py to pre-process the BTCV dataset and generate splits_final.pkl.
Run python Convert_VSseg_to_nnUNet_dataset.py to pre-process the VS dataset and generate splits_final.pkl.
Copy Downstream/nnunet to replace nnunet, which is installed by pip install nnunet (the path is usually 'anaconda3/envs/your envs/lib/python3.8/site-packages/nnunet').
Run nnUNet_plan_and_preprocess -t 60 --verify_dataset_integrity.
Run nnUNet_plan_and_preprocess -t 61 --verify_dataset_integrity.
Move splits_final.pkl of the two datasets to the directories of their pre-processed datasets.
The data splits for the two datasets are provided in Downstream/splits_final_BTCV.pkl and Downstream/splits_final_VS.pkl.
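Placing a split file into the pre-processed folder of a task can be sketched as below. The helper names are hypothetical. The sketch assumes the TaskXXX_Name folder naming seen in paths such as Task091_MOTS and looks the folder up by task ID, so no folder name has to be guessed.

```python
import shutil
from pathlib import Path


def find_task_folder(preprocessed_root, task_id):
    """Return the single folder named Task<ID>_* under `preprocessed_root`."""
    matches = sorted(Path(preprocessed_root).glob(f"Task{task_id:03d}_*"))
    if len(matches) != 1:
        raise LookupError(
            f"expected exactly one folder for task {task_id}, found {len(matches)}"
        )
    return matches[0]


def install_split(split_file, preprocessed_root, task_id):
    """Copy `split_file` to <preprocessed_root>/Task<ID>_*/splits_final.pkl."""
    target = find_task_folder(preprocessed_root, task_id) / "splits_final.pkl"
    shutil.copy2(split_file, target)
    return target
```

For example, install_split("Downstream/splits_final_BTCV.pkl", preprocessed_root, task_id) would be called once per dataset after nnUNet_plan_and_preprocess has created the task folder, with preprocessed_root set to the value of nnUNet_preprocessed and task_id set to the ID used for that dataset.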
CUDA_VISIBLE_DEVICES=0 nnUNet_n_proc_DA=32 nnUNet_train 3d_fullres UniSeg_Trainer_DS 60 0
CUDA_VISIBLE_DEVICES=0 nnUNet_n_proc_DA=32 nnUNet_train 3d_fullres UniSeg_Trainer_DS 61 0
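When both downstream runs are launched from a script, the two commands above can be assembled programmatically. The sketch below only builds the argument list and the environment; the helper names are hypothetical, and the actual launch is left commented out because it requires the modified nnunet to be installed.

```python
import os


def train_command(task_id, fold=0, trainer="UniSeg_Trainer_DS", network="3d_fullres"):
    """Return the nnUNet_train argument list in the order used above."""
    return ["nnUNet_train", network, trainer, str(task_id), str(fold)]


def train_env(gpu=0, num_da_procs=32, base_env=None):
    """Return a copy of the environment with the GPU and data-augmentation settings."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu)
    env["nnUNet_n_proc_DA"] = str(num_da_procs)
    return env


if __name__ == "__main__":
    for task_id in (60, 61):
        print(" ".join(train_command(task_id)))
        # To launch for real:
        # import subprocess
        # subprocess.run(train_command(task_id), env=train_env(), check=True)
```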
Move the trained model and its .pkl file to ./nnUNet_trained_models/UniSeg_Trainer/3d_fullres/Task097_11task/UniSeg_Trainer__DoDNetPlans/fold_0/ and rename them to model_final_checkpoint.model and model_final_checkpoint.model.pkl, respectively.
cd Upstream
Copy Upstream/nnunet to replace nnunet, which is installed by pip install nnunet.
CUDA_VISIBLE_DEVICES=1 nnUNet_n_proc_DA=32 nnUNet_predict -i /data/userdisk0/ywye/nnUNet_raw/nnUNet_raw_data/Test/Image/ -o /data/userdisk0/ywye/nnUNet_raw/nnUNet_raw_data/Test/Predict/10/ -t 97 -m 3d_fullres -tr UniSeg_Trainer -f 0 -task_id 7 -exp_name UniSeg_Trainer -num_image 1 -modality CT -spacing 3.0,1.5,1.5
-i: Path of the input image(s). Input image name format: name_0000.nii.gz (name_0001.nii.gz)
-o: Path of the output mask(s)
-task_id: Selected segmentation task. -1 means predicting all segmentation tasks under a specific modality.
-num_image: Channel number of the input image(s)
-modality: "CT" or "MR" (prostate) or "MR,MR,MR,MR" (brain tumors) or "CT,PET" (whole-body tumors)
-spacing: Spacing of resampled image(s)
If this code is helpful for your study, please cite:
@inproceedings{ye2023uniseg,
title={UniSeg: A Prompt-driven Universal Segmentation Model as well as A Strong Representation Learner},
author={Ye, Yiwen and Xie, Yutong and Zhang, Jianpeng and Chen, Ziyang and Xia, Yong},
booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
pages={508--518},
year={2023},
organization={Springer}
}
The whole framework is based on nnUNet v1.
Yiwen Ye (ywye@mail.nwpu.edu.cn)