The official repo of "SAM-Med3D: Towards General-purpose Segmentation Models for Volumetric Medical Images".
A quick-start inference script is available: medim_infer.py.
SAM-Med3D-turbo is now released. We fine-tuned it on 44 datasets (list) to improve performance, and we hope this update helps you. Note: labels are currently required to generate prompt points for inference.
First, set up your environment with the following commands:
conda create --name sammed3d python=3.10
conda activate sammed3d
pip install light-the-torch && ltt install torch
pip install torchio opencv-python-headless matplotlib prefetch_generator monai edt medim
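After installation, a quick import check can confirm the environment is complete. The sketch below is a hypothetical helper (missing_modules and REQUIRED_MODULES are illustrative names, not part of the repo). It uses only the standard library and assumes the pip-name to import-name mapping shown, for example opencv-python-headless imports as cv2.

```python
import importlib.util

# pip package name -> import name (assumed mapping for the packages installed above)
REQUIRED_MODULES = {
    "torch": "torch",
    "torchio": "torchio",
    "opencv-python-headless": "cv2",
    "matplotlib": "matplotlib",
    "prefetch_generator": "prefetch_generator",
    "monai": "monai",
    "edt": "edt",
    "medim": "medim",
}


def missing_modules(modules):
    """Return the pip names whose top-level import name cannot be found."""
    missing = []
    for package, import_name in modules.items():
        # find_spec returns None for an uninstalled top-level module, without importing it
        if importlib.util.find_spec(import_name) is None:
            missing.append(package)
    return missing


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```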
Then, use medim_infer.py to test the inference:
python medim_infer.py
If you want to run inference on your own data, refer to medim_infer.py for more details. You can simply modify the paths in the script to use your own data. Here's the main logic:
import os
import os.path as osp

import medim
# Assumption: these helpers are the ones defined in medim_infer.py (run from the repo root)
from medim_infer import data_preprocess, data_postprocess, sam_model_infer

# 1. Read and pre-process your input data
img_path = "./test_data/kidney_right/AMOS/imagesVal/amos_0013.nii.gz"
gt_path = "./test_data/kidney_right/AMOS/labelsVal/amos_0013.nii.gz"
category_index = 3  # the index of your target category in the gt annotation
output_dir = "./test_data/kidney_right/AMOS/pred/"
roi_image, roi_label, meta_info = data_preprocess(img_path, gt_path, category_index=category_index)

# 2. Prepare the pre-trained model from a local path or a Hugging Face URL
ckpt_path = "https://huggingface.co/blueyo0/SAM-Med3D/blob/main/sam_med3d_turbo.pth"
# or use a local path, e.g. ckpt_path = "./ckpt/sam_med3d_turbo.pth"
model = medim.create_model("SAM-Med3D",
                           pretrained=True,
                           checkpoint_path=ckpt_path)

# 3. Infer with the pre-trained SAM-Med3D model (the label is used to generate prompt points)
roi_pred = sam_model_infer(model, roi_image, roi_gt=roi_label)

# 4. Post-process and save the result
os.makedirs(output_dir, exist_ok=True)
output_path = osp.join(output_dir, osp.basename(img_path).replace(".nii.gz", "_pred.nii.gz"))
data_postprocess(roi_pred, meta_info, output_path, img_path)
print("result saved to", output_path)
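Labels are currently required because inference relies on simulated prompt points. The sketch below illustrates one simple way a point prompt can be derived from a binary label: sample a random foreground voxel. It is an assumption-based illustration, not the sampling strategy used in medim_infer.py, and sample_point_prompt is a hypothetical name. The mask is a nested list indexed as mask[z][y][x] to keep the example dependency-free.

```python
import random


def sample_point_prompt(mask, seed=None):
    """Pick one foreground voxel (z, y, x) from a 3D binary mask.

    `mask` is a nested list indexed as mask[z][y][x]; any non-zero value
    counts as foreground. Returns None if the mask has no foreground.
    """
    foreground = [
        (z, y, x)
        for z, plane in enumerate(mask)
        for y, row in enumerate(plane)
        for x, value in enumerate(row)
        if value
    ]
    if not foreground:
        return None  # no target in this label, so no prompt can be simulated
    rng = random.Random(seed)  # seeded for reproducible prompts
    return rng.choice(foreground)
```

For example, `sample_point_prompt(mask, seed=0)` returns the same voxel on every run, which helps keep evaluations reproducible.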
(We recommend fine-tuning from the SAM-Med3D pre-trained weights; see the checkpoint section.)
To train the SAM-Med3D model on your own data, follow these steps:
Note: you can easily get the PyTorch SAM-Med3D model with pre-trained weights from Hugging Face using MedIM:
import medim
ckpt_path = "https://huggingface.co/blueyo0/SAM-Med3D/blob/main/sam_med3d_turbo.pth"
model = medim.create_model("SAM-Med3D", pretrained=True, checkpoint_path=ckpt_path)
Download the checkpoint from the ckpt section and move the .pth file into SAM_Med3D/ckpt/ (we recommend using SAM-Med3D-turbo.pth).
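If you fetch the checkpoint manually (for example with wget or curl) instead of through MedIM, note that a Hugging Face `/blob/` URL points to the HTML page for the file, while the `/resolve/` URL serves the raw file. The helper below is a hypothetical convenience function, not part of the repo.

```python
def hf_direct_download_url(url):
    """Convert a Hugging Face '/blob/' file-page URL into a '/resolve/' raw-file URL.

    URLs that are not Hugging Face blob links are returned unchanged.
    """
    marker = "/blob/"
    if "huggingface.co/" in url and marker in url:
        # only the first occurrence is the path segment we want to swap
        return url.replace(marker, "/resolve/", 1)
    return url


if __name__ == "__main__":
    print(hf_direct_download_url(
        "https://huggingface.co/blueyo0/SAM-Med3D/blob/main/sam_med3d_turbo.pth"))
```

The printed URL can then be passed to a downloader, e.g. `wget -O ckpt/sam_med3d_turbo.pth <url>`.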
Ensure that your training data are organized according to the structure shown in the data/medical_preprocessed directory. The target file structure should look like the following:
data/medical_preprocessed
├── adrenal
│   ├── ct_WORD
│   │   ├── imagesTr
│   │   │   ├── word_0025.nii.gz
│   │   │   ├── ...
│   │   ├── labelsTr
│   │   │   ├── word_0025.nii.gz
│   │   │   ├── ...
├── ...
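Before training, a quick consistency check of the layout above can catch images without labels (or vice versa). The sketch assumes each dataset directory holds imagesTr/ and labelsTr/ with identically named .nii.gz files, as in the tree; check_dataset_dir is a hypothetical helper, not part of the repo.

```python
from pathlib import Path


def check_dataset_dir(dataset_dir):
    """Compare imagesTr/ and labelsTr/ inside one dataset directory.

    Returns (images_without_label, labels_without_image) as sorted lists of
    file names; both are empty when every image has a same-named label.
    """
    root = Path(dataset_dir)
    # glob on a missing directory simply yields nothing
    images = {p.name for p in (root / "imagesTr").glob("*.nii.gz")}
    labels = {p.name for p in (root / "labelsTr").glob("*.nii.gz")}
    return sorted(images - labels), sorted(labels - images)


if __name__ == "__main__":
    print(check_dataset_dir("data/medical_preprocessed/adrenal/ct_WORD"))
```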
If the original data are in the nnU-Net style, the original file structure should be:
Task010_WORD
├── imagesTr
│   ├── word_0025_0000.nii.gz
│   ├── ...
├── labelsTr
│   ├── word_0025.nii.gz
│   ├── ...
Then resample the data and convert the masks into binary masks (you can use the provided script for nnU-Net folders). The resulting structure should be:
data/train
├── adrenal
│   ├── ct_WORD
│   │   ├── imagesTr
│   │   │   ├── word_0025.nii.gz
│   │   │   ├── ...
│   │   ├── labelsTr
│   │   │   ├── word_0025.nii.gz (binary label)
│   │   │   ├── ...
├── liver
│   ├── ct_WORD
│   │   ├── imagesTr
│   │   │   ├── word_0025.nii.gz
│   │   │   ├── ...
│   │   ├── labelsTr
│   │   │   ├── word_0025.nii.gz (binary label)
│   │   │   ├── ...
├── ...
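The renaming and binarization implied by the two trees can be sketched as follows. This is an illustration under stated assumptions, not the provided conversion script: it assumes single-channel images with the nnU-Net `_0000` suffix, represents label values as a flat list instead of a volume, and leaves resampling out. All function names are hypothetical.

```python
def strip_channel_suffix(image_name, channel="0000"):
    """'word_0025_0000.nii.gz' -> 'word_0025.nii.gz'. Apply to imagesTr names only."""
    suffix = f"_{channel}.nii.gz"
    if image_name.endswith(suffix):
        return image_name[: -len(suffix)] + ".nii.gz"
    return image_name


def binarize(label_values, category_index):
    """Keep one category as foreground (1); every other value becomes background (0)."""
    return [1 if v == category_index else 0 for v in label_values]


def target_paths(image_name, organ, dataset, root="data/train"):
    """Return (image_path, label_path) for one case in the target layout."""
    case = strip_channel_suffix(image_name)
    base = f"{root}/{organ}/{dataset}"
    return f"{base}/imagesTr/{case}", f"{base}/labelsTr/{case}"


if __name__ == "__main__":
    print(target_paths("word_0025_0000.nii.gz", "adrenal", "ct_WORD"))
    print(binarize([0, 3, 1, 3], category_index=3))
```

Each organ in a multi-class nnU-Net label becomes its own `<organ>/<dataset>` folder with one binary label per case, matching the data/train tree.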
Then, modify utils/data_paths.py according to your own data:
img_datas = [
"data/train/adrenal/ct_WORD",
"data/train/liver/ct_WORD",
...
]
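Instead of listing every directory by hand, img_datas could be built by scanning the data/train layout. The sketch below is a hypothetical alternative (discover_img_datas is not part of the repo); it assumes the two-level `<organ>/<dataset>` structure shown above and that paths are resolved relative to where train.py is run.

```python
from pathlib import Path


def discover_img_datas(train_root="data/train"):
    """List every <organ>/<dataset> directory that has both imagesTr/ and labelsTr/."""
    root = Path(train_root)
    found = []
    for dataset_dir in sorted(root.glob("*/*")):
        if (dataset_dir / "imagesTr").is_dir() and (dataset_dir / "labelsTr").is_dir():
            found.append(dataset_dir.as_posix())
    return found


if __name__ == "__main__":
    img_datas = discover_img_datas("data/train")
    print(img_datas)
```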
Run bash train.sh to execute the following command in your terminal:
python train.py --multi_gpu --task_name ${tag}
This will start the training process of the SAM-Med3D model on your prepared data. If you use only one GPU, remove the --multi_gpu flag.
The key options are listed below:
Hint: use --checkpoint to set the path of the pre-trained weights. If no checkpoint is found at that path, the model will be trained from scratch!
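The GPU and checkpoint rules above can be captured in a small command builder. This is a hypothetical sketch (build_train_cmd is not in the repo) that only uses the flags named in this section (--multi_gpu, --task_name, --checkpoint); train.sh may pass further options.

```python
import os
import shlex


def build_train_cmd(task_name, num_gpus=1, checkpoint=None):
    """Assemble the train.py command line from the options described above."""
    cmd = ["python", "train.py"]
    if num_gpus > 1:
        cmd.append("--multi_gpu")  # drop this flag when training on a single GPU
    cmd += ["--task_name", task_name]
    if checkpoint is not None:
        if not os.path.exists(checkpoint):
            # per the hint above, a missing checkpoint means training from scratch
            print(f"Warning: {checkpoint} not found; the model would be trained from scratch.")
        cmd += ["--checkpoint", checkpoint]
    return cmd


if __name__ == "__main__":
    cmd = build_train_cmd("my_task", num_gpus=2, checkpoint="ckpt/sam_med3d_turbo.pth")
    print(shlex.join(cmd))
```

Checking the path up front makes an accidental from-scratch run visible before any GPU time is spent.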
Prepare your own dataset and refer to the samples in data/validation, replacing them according to your specific scenario.
Then you can simply run bash val.sh to quickly validate SAM-Med3D on your data, or run bash infer.sh to generate full-volume results for your application.
Make sure the masks are processed into a binary (one-hot) format, i.e., they contain only two values: the target (foreground) and the background. We highly recommend a voxel spacing of 1.5 mm for the best results.
python validation.py --seed 2023 \
    -vp ./results/vis_sam_med3d \
    -cp ./ckpt/sam_med3d_turbo.pth \
    -tdp ./data/medical_preprocessed -nc 1 \
    --save_name ./results/sam_med3d.py
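The two data recommendations for validation (binary masks, 1.5 mm spacing) can be checked with a little arithmetic: resampling an axis of n voxels at spacing s to 1.5 mm yields roughly n × s / 1.5 voxels. The helpers below are hypothetical and only compute the expected grid; the actual resampling can be done with a library such as TorchIO (installed above), e.g. its `Resample(1.5)` transform.

```python
def is_binary(values):
    """True if the mask contains only 0 (background) and/or 1 (foreground)."""
    return set(values) <= {0, 1}


def resampled_shape(shape, spacing, target_spacing=1.5):
    """Approximate voxel grid size after resampling to `target_spacing` mm."""
    return tuple(
        max(1, round(n * s / target_spacing))  # physical extent / new spacing
        for n, s in zip(shape, spacing)
    )


if __name__ == "__main__":
    # e.g. a CT volume of 512 x 512 x 100 voxels at 0.8 x 0.8 x 5.0 mm
    print(resampled_shape((512, 512, 100), (0.8, 0.8, 5.0)))
    print(is_binary([0, 1, 1, 0]), is_binary([0, 1, 3]))
```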
Sliding-window inference (experimental): to extend the application scenarios of SAM-Med3D and support more options for full-volume inference, we provide a sliding-window mode in inference.py.
python inference.py --seed 2024 \
    -cp ./ckpt/sam_med3d_turbo.pth \
    -tdp ./data/medical_preprocessed -nc 1 \
    --output_dir ./results --task_name test_amos_move \
    --sliding_window --save_image_and_gt
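As a rough illustration of sliding-window tiling (not the logic in inference.py), the sketch below computes window origins that cover a full volume. It assumes cubic 128-voxel windows, matching the 128×128×128 input resolution in the results table, and a hypothetical 64-voxel overlap; window_starts and window_grid are illustrative names.

```python
from itertools import product


def window_starts(size, patch=128, overlap=64):
    """Start indices of length-`patch` windows covering [0, size) along one axis."""
    if size <= patch:
        return [0]  # a single (padded) window covers the whole axis
    stride = patch - overlap
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] + patch < size:
        starts.append(size - patch)  # last window flush with the volume edge
    return starts


def window_grid(shape, patch=128, overlap=64):
    """All 3D window origins (z, y, x) for a volume of the given shape."""
    return list(product(*(window_starts(n, patch, overlap) for n in shape)))


if __name__ == "__main__":
    print(window_starts(300))
    print(len(window_grid((100, 256, 300))))
```

Predictions from overlapping windows would then need to be merged (for example by averaging), which is where the overlap trades extra compute for smoother boundaries.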
Results are saved in <output_dir>/<task_name>. The --save_image_and_gt flag also saves the images and ground-truth labels to output_dir, so please ensure your disk has enough free space when you turn it on.
For validation of SAM and SAM-Med2D on 3D volumetric data, refer to scripts/val_sam.sh and scripts/val_med2d.sh for details.
Hint: we also provide a simple script, sum_result.py, to help summarize the results from files like ./results/sam_med3d.py.
Our recommended version is SAM-Med3D-turbo.
Model | Google Drive | Baidu NetDisk |
---|---|---|
SAM-Med3D | Download | Download (pwd:r5o3) |
SAM-Med3D-organ | Download | Download (pwd:5t7v) |
SAM-Med3D-brain | Download | Download (pwd:yp42) |
SAM-Med3D-turbo | Download | Download (pwd:l6ol) |
Other checkpoints are available with their official link: SAM and SAM-Med2D.
Model | Prompt | Resolution | Inference Time (s) | Overall Dice |
---|---|---|---|---|
SAM | N points | 1024×1024×N | 13 | 16.15 |
SAM-Med2D | N points | 256×256×N | 4 | 36.83 |
SAM-Med3D | 1 point | 128×128×128 | 2 | 38.65 |
SAM-Med3D | 10 points | 128×128×128 | 6 | 49.02 |
SAM-Med3D-turbo | 1 point | 128×128×128 | 6 | 76.27 |
SAM-Med3D-turbo | 10 points | 128×128×128 | 6 | 80.71 |
Note: Quantitative comparison of different methods on our evaluation dataset. Here, N denotes the number of slices containing the target object (10 ≤ N ≤ 200). Inference time is measured with N = 100, excluding the time for image processing and simulated prompt generation.
Model | Prompt | A&T | Bone | Brain | Cardiac | Muscle | Lesion | Unseen Organ | Unseen Lesion |
---|---|---|---|---|---|---|---|---|---|
SAM | N points | 19.93 | 17.85 | 29.73 | 8.44 | 3.93 | 11.56 | 12.14 | 8.88 |
SAM-Med2D | N points | 50.47 | 32.70 | 36.00 | 40.18 | 43.85 | 24.90 | 19.36 | 44.87 |
SAM-Med3D | 1 point | 46.12 | 33.30 | 49.14 | 61.04 | 53.78 | 39.56 | 23.85 | 40.53 |
SAM-Med3D | 10 points | 58.61 | 43.52 | 54.01 | 68.50 | 69.45 | 47.87 | 29.05 | 48.44 |
SAM-Med3D-turbo | 1 point | 80.76 | 83.38 | 43.74 | 87.12 | 89.74 | 58.06 | 35.99 | 44.22 |
SAM-Med3D-turbo | 10 points | 85.42 | 85.34 | 61.27 | 90.97 | 91.62 | 64.80 | 48.10 | 62.72 |
Note: Comparison from the perspective of anatomical structures and lesions. A&T denotes abdominal and thoracic targets. N denotes the number of slices containing the target object (10 ≤ N ≤ 200).
@misc{wang2023sammed3d,
title={SAM-Med3D},
author={Haoyu Wang and Sizheng Guo and Jin Ye and Zhongying Deng and Junlong Cheng and Tianbin Li and Jianpin Chen and Yanzhou Su and Ziyan Huang and Yiqing Shen and Bin Fu and Shaoting Zhang and Junjun He and Yu Qiao},
year={2023},
eprint={2310.15161},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
This project is released under the Apache 2.0 license.
(If the QR code has expired, please contact the WeChat account EugeneYonng or Small_dark8023, and include the note "add sammed3d wechat" / "sammed3d discussion group".)
BTW, welcome to follow our Zhihu official account, we will share more information on medical imaging there.