The Segment Anything Model (SAM) has recently emerged as a groundbreaking model in the field of image segmentation. Nevertheless, both the original SAM and its medical adaptations require slice-by-slice annotation, so the annotation workload grows directly with the size of the dataset. We propose MedLSAM to address this issue, keeping the annotation workload constant irrespective of dataset size and thereby simplifying the annotation process. Our model introduces a few-shot localization framework capable of localizing any target anatomical part within the body. To achieve this, we develop a Localize Anything Model for 3D Medical Images (MedLAM), trained with two self-supervision tasks, relative distance regression (RDR) and multi-scale similarity (MSS), on a comprehensive dataset of 14,012 CT scans. We then establish a methodology for accurate segmentation by integrating MedLAM with SAM. By annotating only six extreme points across three directions on a few templates, our model can autonomously identify the target anatomical region on all data scheduled for annotation. This allows our framework to generate a 2D bounding box for every slice of the image, which SAM then uses to carry out the segmentation. We conducted experiments on two 3D datasets covering 38 organs and found that MedLSAM matches the performance of SAM and its medical adaptations while requiring only minimal extreme-point annotation for the entire dataset. Furthermore, MedLAM has the potential to be seamlessly integrated with future 3D SAM models, paving the way for enhanced performance.
*Fig.1 The overall segmentation pipeline of MedLSAM.*
Requirements:

- torch>=1.11.0
- tqdm
- nibabel
- scipy
- SimpleITK
- monai

### Installation
- Create a virtual environment and activate it: `conda create -n medlsam python=3.10 -y`, then `conda activate medlsam`
- Install PyTorch
- Clone the repository: `git clone https://github.com/openmedlab/MedLSAM`
- Enter the MedLSAM folder and install the package: `cd MedLSAM`, then `pip install -e .`
Download the MedLAM, SAM, and MedSAM checkpoints and place them at `checkpoint/medlam.pth`, `checkpoint/sam_vit_b_01ec64.pth`, and `checkpoint/medsam_vit_b.pth`, respectively.
We recommend using a GPU with 12GB or more memory for inference.
Note: You can also download other CT datasets and place them anywhere you like. MedLSAM automatically applies its preprocessing procedure at inference time, so please do not normalize the original CT images!
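If you are unsure whether a volume is still raw, note that unnormalized CT is stored in Hounsfield units and typically spans roughly [-1024, 3000], whereas an already-normalized image sits in a narrow range such as [0, 1]. A quick illustrative check (the file path is a placeholder):

```python
import SimpleITK as sitk

# Load a CT volume and inspect its intensity range.
# Raw (unnormalized) CT is stored in Hounsfield units and usually
# spans roughly [-1024, 3000]; values confined to [0, 1] suggest
# the image was already normalized and should not be used as-is.
image = sitk.ReadImage("/path/to/your/dataset/image_1.nii.gz")
array = sitk.GetArrayFromImage(image)
print(f"intensity range: [{array.min()}, {array.max()}]")
```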
After downloading the datasets, you should sort the data into "support" and "query" groups. This does not require moving the actual image files. Rather, you need to create separate lists of file paths for each group.
For each group ("support" and "query"), perform the following steps:
- Create a `.txt` file listing the paths to the image files.
- Create a separate `.txt` file listing the paths to the corresponding label files.

Ensure that the ordering of images and labels aligns in both lists; the file names themselves do not matter. These lists direct MedLSAM to the appropriate files during inference.
Example format for the .txt files:
image.txt
/path/to/your/dataset/image_1.nii.gz
...
/path/to/your/dataset/image_n.nii.gz
label.txt
/path/to/your/dataset/label_1.nii.gz
...
/path/to/your/dataset/label_n.nii.gz
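These lists can be written by hand or generated with a few lines of Python. A minimal sketch, assuming your files follow the naming pattern above (adapt the glob patterns to your dataset):

```python
import glob

# Collect image/label paths and sort them so that the i-th image
# corresponds to the i-th label in both lists.
images = sorted(glob.glob("/path/to/your/dataset/image_*.nii.gz"))
labels = sorted(glob.glob("/path/to/your/dataset/label_*.nii.gz"))
assert len(images) == len(labels), "every image needs a matching label"

with open("image.txt", "w") as f:
    f.write("\n".join(images) + "\n")
with open("label.txt", "w") as f:
    f.write("\n".join(labels) + "\n")
```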
MedLAM and MedLSAM load their configurations from a .txt file. The structure of the file is as follows:
[data]
support_image_ls = config/data/StructSeg_HaN/support_image.txt
support_label_ls = config/data/StructSeg_HaN/support_label.txt
query_image_ls = config/data/StructSeg_HaN/query_image.txt
query_label_ls = config/data/StructSeg_HaN/query_label.txt
gt_slice_threshold = 10
bbox_mode = SPL
slice_interval = 2
fg_class = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22]
seg_save_path = result/npz/StructSeg
seg_png_save_path = result/png/StructSeg
[vit]
net_type = vit_b
[weight]
medlam_load_path = checkpoint/medlam.pth
vit_load_path = checkpoint/medsam_20230423_vit_b_0.0.1.pth
Each of the parameters is explained as follows:
- `support_image_ls`: The path to the list of support image files. It is recommended to use between 3 and 10 support images.
- `support_label_ls`: The path to the list of support label files.
- `query_image_ls`: The path to the list of query image files.
- `query_label_ls`: The path to the list of query label files.
- `gt_slice_threshold`: The threshold value for ground truth slice selection.
- `bbox_mode`: The bounding box mode. It could be `SPL` (Sub-Patch Localization) or `WPL` (Whole-Patch Localization), as shown in Fig.2.
- `slice_interval`: Specifies the number of slices in a sub-patch. A smaller value results in more patches. This parameter should be of type `int`, and its value should be greater than 0. **Applicable only for Sub-Patch Localization (SPL), set to `False` for Whole-Patch Localization (WPL)**.
- `fg_class`: The list of foreground classes to be used for localization and segmentation, given as a list of integer class labels. You may select only a subset of them as target classes.
- `seg_save_path`: The path to save the segmentation results in .npz format, **only required for MedLSAM**.
- `seg_png_save_path`: The path to save the segmentation results in .png format, **only required for MedLSAM**.
- `net_type`: The type of vision transformer model to be used, **only required for MedLSAM**. By default, this is set to vit_b.
- `medlam_load_path`: The path to load the pretrained MedLAM model weights.
- `vit_load_path`: The path to load the pretrained vision transformer model weights, **only required for MedLSAM**. You can change it to `checkpoint/sam_vit_b_01ec64.pth` to use the SAM model as segmentation basis.
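To sanity-check a configuration file like the one above before launching inference, you can parse it with Python's standard `configparser` (a sketch for inspection only; MedLSAM's own loader may read the file differently). Note that `fg_class` is written as a Python-style list, so it needs an explicit `ast.literal_eval`:

```python
import ast
import configparser

# Parse a MedLSAM-style config file and echo a few fields.
config = configparser.ConfigParser()
config.read("config/test_config/test_structseg_medlam.txt")

support_images = config.get("data", "support_image_ls")
bbox_mode = config.get("data", "bbox_mode")          # "SPL" or "WPL"
fg_class = ast.literal_eval(config.get("data", "fg_class"))

print(f"support list: {support_images}")
print(f"bbox mode:    {bbox_mode}")
print(f"{len(fg_class)} foreground classes: {fg_class}")
```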
<div align="center">
<img src="https://github.com/openmedlab/MedLSAM/raw/main/fig/wpl_spl.png" width="80%">
</div>
*Fig.2 Comparison between Whole-Patch Localization (WPL) and Sub-Patch Localization (SPL) strategies.*
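Conceptually, WPL derives a single 2D box from the extreme points of the whole target and reuses it on every slice, whereas SPL splits the slice range into sub-patches of `slice_interval` slices and fits a tighter box for each sub-patch. The toy numpy sketch below illustrates the difference on a binary mask; it is a conceptual illustration, not MedLSAM's actual implementation:

```python
import numpy as np

def wpl_boxes(mask: np.ndarray) -> dict:
    """One box from the whole 3D foreground, reused on every slice."""
    zs, ys, xs = np.nonzero(mask)  # mask has shape (Z, Y, X)
    box = (ys.min(), xs.min(), ys.max(), xs.max())
    return {z: box for z in range(zs.min(), zs.max() + 1)}

def spl_boxes(mask: np.ndarray, slice_interval: int = 2) -> dict:
    """A tighter box per sub-patch of `slice_interval` slices."""
    zs = np.nonzero(mask)[0]
    boxes = {}
    for start in range(int(zs.min()), int(zs.max()) + 1, slice_interval):
        sub = mask[start:start + slice_interval]
        if not sub.any():
            continue
        _, ys, xs = np.nonzero(sub)
        box = (ys.min(), xs.min(), ys.max(), xs.max())
        for z in range(start, min(start + slice_interval, int(zs.max()) + 1)):
            boxes[z] = box
    return boxes
```

The tighter per-sub-patch boxes from SPL are intended to track targets whose cross-section varies along the axial direction, where a single whole-target box would be overly loose on many slices.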
### Inference
- MedLAM (**Localize any anatomy target**)
```bash
CUDA_VISIBLE_DEVICES=0 python MedLAM_Inference.py --config_file path/to/your/test_medlam_config.txt
```
Example:
```bash
CUDA_VISIBLE_DEVICES=0 python MedLAM_Inference.py --config_file config/test_config/test_structseg_medlam.txt
CUDA_VISIBLE_DEVICES=0 python MedLAM_Inference.py --config_file config/test_config/test_word_medlam.txt
```
- MedLSAM (**Localize and segment any anatomy target**)
```bash
CUDA_VISIBLE_DEVICES=0 python MedLSAM_WPL_Inference.py --config_file path/to/your/test_medlsam_config.txt
CUDA_VISIBLE_DEVICES=0 python MedLSAM_SPL_Inference.py --config_file path/to/your/test_medlsam_config.txt
```
Example:
```bash
CUDA_VISIBLE_DEVICES=0 python MedLSAM_WPL_Inference.py --config_file config/test_config/test_structseg_medlam_wpl_medsam.txt
CUDA_VISIBLE_DEVICES=0 python MedLSAM_SPL_Inference.py --config_file config/test_config/test_structseg_medlam_spl_sam.txt
```
### Training

- Prepare a `train/config/ori_nii.txt` file listing the paths to the original CT nii files. (MedLAM is based on self-supervised learning tasks, so no label files are required during training!)
- Run `python train/dataset_preprocess.py`. It will automatically preprocess the CT files. By default, each preprocessed file is saved with `_pre` appended to the original filename; for example, if your original file is named `scan.nii`, the preprocessed file will be named `scan_pre.nii`. The script also generates a `pre_nii.txt` file located in the `train/config/` directory.
- Run `python train/train_position_full_size_with_fc.py -c train/config/train_position_full_size_with_fc.txt` to train MedLAM.
- Checkpoints are saved in `train/checkpoint` and contain both the network weights and the optimizer states. Run `python train/extract_weights.py -p train/checkpoint/your.tar` (change the tar path to your file) to extract the network weights from the checkpoint and save them as `checkpoint/medlam.pth`.
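For reference, the extraction step conceptually amounts to loading the full checkpoint and re-saving only the network weights. A minimal sketch (the `"model_state_dict"` key is an assumed placeholder; defer to `train/extract_weights.py` for the actual checkpoint structure):

```python
import torch

# Load the full training checkpoint (network weights + optimizer state)
# and save only the network weights. The key name below is an assumed
# placeholder; check train/extract_weights.py for the real structure.
checkpoint = torch.load("train/checkpoint/your.tar", map_location="cpu")
torch.save(checkpoint["model_state_dict"], "checkpoint/medlam.pth")
```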
### License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.
If you find this repository useful, please consider citing this paper:
@article{Lei2023medlam,
  title={MedLSAM: Localize and Segment Anything Model for 3D Medical Images},
  author={Wenhui Lei and Xu Wei and Xiaofan Zhang and Kang Li and Shaoting Zhang},
  journal={arXiv preprint arXiv:},
  year={2023}
}