cchen-cc / MA-SAM

PyTorch implementation for MA-SAM
Apache License 2.0

MA-SAM: Modality-agnostic SAM Adaptation for 3D Medical Image Segmentation

This is a PyTorch implementation of the paper MA-SAM: Modality-agnostic SAM Adaptation for 3D Medical Image Segmentation.

Figure: Overview of the MA-SAM framework

We introduce a modality-agnostic SAM adaptation framework, named MA-SAM, that is applicable to various volumetric and video medical data. Our method has been comprehensively evaluated on four medical image segmentation tasks using 10 public datasets across CT, MRI, and surgical video data. Without using any prompts, our method consistently outperforms various state-of-the-art 3D approaches, surpassing nnU-Net by 0.9%, 2.6%, and 9.9% in Dice for CT multi-organ segmentation, MRI prostate segmentation, and surgical scene segmentation, respectively. Our model also demonstrates strong generalization and excels in challenging tumor segmentation when prompts are used.

Usage

Environment Requirements

Installation

Clone this repository and then install the dependencies.

git clone https://github.com/cchen-cc/MA-SAM.git
conda create -n masam python=3.10.12
conda activate masam
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
cd MA-SAM
pip install -r requirements.txt
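After installation, a quick sanity check can confirm that the interpreter and PyTorch packages are visible before launching a long training run. The sketch below is a hypothetical helper (check_environment is not part of the repository): it only tests whether packages can be imported and whether torch sees a CUDA device; it does not verify the versions pinned in requirements.txt.

```python
import importlib
import importlib.util
import sys
from typing import Dict, Sequence


def check_environment(packages: Sequence[str] = ("torch", "torchvision", "torchaudio")) -> Dict[str, object]:
    """Report the Python version, missing packages, and CUDA visibility.

    Hypothetical helper: only importability is checked, not package versions.
    """
    report: Dict[str, object] = {
        "python": ".".join(str(part) for part in sys.version_info[:3]),
        # The conda environment created above uses Python 3.10.
        "python_ok": sys.version_info[:2] >= (3, 10),
        "missing": [name for name in packages if importlib.util.find_spec(name) is None],
        "cuda_available": False,
    }
    # Import torch only if it is installed, so the check itself never crashes.
    if importlib.util.find_spec("torch") is not None:
        torch = importlib.import_module("torch")
        report["cuda_available"] = bool(torch.cuda.is_available())
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

Running it inside the activated masam environment should list no missing packages and report cuda_available: True on a GPU machine.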

Data Preparation

The dataset_split file provides the dataset splits used in our work.
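The Training notes below mention that MA-SAM takes a number of consecutive slices per input, so changing that number requires changing the pre-processing. As a hypothetical sketch of the idea only (the helper name, the default of 5 slices, and edge replication at the volume borders are assumptions, not taken from the repository), one window of neighbouring slices can be built for every slice index:

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def consecutive_slice_windows(slices: Sequence[T], num_slices: int = 5) -> List[List[T]]:
    """Return one window of `num_slices` neighbouring slices per slice index.

    Hypothetical sketch: indices past either end of the volume are clamped to
    the first/last slice (edge replication), so every window has equal length.
    """
    if num_slices < 1 or num_slices % 2 == 0:
        raise ValueError("num_slices must be a positive odd number")
    half = num_slices // 2
    last = len(slices) - 1
    windows = []
    for center in range(len(slices)):
        # Offsets -half..+half around the center slice, clamped to valid indices.
        idx = [min(max(center + offset, 0), last) for offset in range(-half, half + 1)]
        windows.append([slices[i] for i in idx])
    return windows
```

With a NumPy volume of shape (D, H, W), iterating over its first axis gives the slices, and np.stack(window, axis=0) turns each window into a (num_slices, H, W) array.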

Training

Before starting, download the SAM pre-trained model weights (SAM ViT_H, SAM ViT_L, or SAM ViT_B) and save them in a suitable folder. Then go to the MA-SAM folder and start training:
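The SAM weights are published by the Segment Anything project. A minimal download sketch follows, assuming the checkpoints go into a local folder named checkpoints (a hypothetical choice); the helper names are illustrative and not part of MA-SAM.

```python
import urllib.request
from pathlib import Path

# Public download links for the original SAM checkpoints (segment-anything release).
SAM_URLS = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}


def sam_checkpoint_path(vit_name: str, ckpt_dir: str = "checkpoints") -> Path:
    """Local file path for the checkpoint of backbone `vit_name`."""
    if vit_name not in SAM_URLS:
        raise ValueError(f"unknown backbone {vit_name!r}; expected one of {sorted(SAM_URLS)}")
    return Path(ckpt_dir) / SAM_URLS[vit_name].rsplit("/", 1)[-1]


def download_sam_checkpoint(vit_name: str, ckpt_dir: str = "checkpoints") -> Path:
    """Download the checkpoint once and return its path; existing files are reused."""
    target = sam_checkpoint_path(vit_name, ckpt_dir)
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)
        partial = target.with_name(target.name + ".part")
        # Write to a temporary name first so an interrupted download is not
        # mistaken for a complete checkpoint on the next call.
        urllib.request.urlretrieve(SAM_URLS[vit_name], str(partial))
        partial.replace(target)
    return target
```

For example, download_sam_checkpoint("vit_h") returns checkpoints/sam_vit_h_4b8939.pth, which can then be passed to --ckpt.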

cd MA-SAM
python train.py --root_path <Your data directory> --output <Your output directory> --ckpt <Your SAM pre-trained model directory>

We use 8 A100 80G GPUs to train our full model. To reduce memory consumption, you may consider switching the backbone from ViT_H to ViT_L or ViT_B. To do so, set --vit_name to 'vit_l' or 'vit_b' and pass the matching SAM pre-trained weights to --ckpt. You may also consider reducing the number of consecutive slices, which requires corresponding changes to the data pre-processing and evaluation. However, using a smaller backbone or fewer consecutive slices leads to a decrease in performance. We do not recommend reducing the batch size, as this makes the model difficult to converge.
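A common mistake when switching backbones is changing --vit_name without changing --ckpt. The sketch below is a hypothetical wrapper (build_train_command is not part of the repository) that assembles the train.py call from the flags documented above and rejects a checkpoint whose file name does not match the chosen backbone. The name check assumes the official SAM file naming (e.g. sam_vit_l_0b3195.pth), and the vit_h default mirrors the full-model setting described here.

```python
import sys
from pathlib import Path
from typing import List

BACKBONES = ("vit_h", "vit_l", "vit_b")


def build_train_command(root_path: str, output: str, ckpt: str,
                        vit_name: str = "vit_h") -> List[str]:
    """Assemble a train.py call whose --ckpt matches --vit_name.

    Hypothetical helper; the mismatch check relies on the official SAM
    checkpoint names, which embed the backbone (e.g. sam_vit_b_01ec64.pth).
    """
    if vit_name not in BACKBONES:
        raise ValueError(f"vit_name must be one of {BACKBONES}, got {vit_name!r}")
    if vit_name not in Path(ckpt).name:
        raise ValueError(f"checkpoint {ckpt!r} does not look like a {vit_name} model")
    return [sys.executable, "train.py",
            "--root_path", root_path,
            "--output", output,
            "--ckpt", ckpt,
            "--vit_name", vit_name]
```

From inside the MA-SAM folder, subprocess.run(build_train_command(...), check=True) launches training. Returning an argument list rather than a shell string avoids quoting problems with paths that contain spaces.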

Inference

We provide our trained model for reproducing our results on the BTCV dataset. To perform inference with the trained MA-SAM model, use the following command:

python test.py --adapt_ckpt <Your MA-SAM model directory> --data_path <Your data directory> --ckpt <Your SAM pre-trained model directory> --is_savenii
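The same command can be built from a script, for example to evaluate several checkpoints in a loop. This is a hypothetical sketch (build_test_command and expected_prediction_dir are illustrative names) using only the test.py flags shown above. It also derives where --is_savenii should write its .nii files, based on the README's description of a folder named after the checkpoint without the .pth extension; placing that folder next to the checkpoint is an assumption.

```python
import sys
from pathlib import Path
from typing import List


def build_test_command(adapt_ckpt: str, data_path: str, sam_ckpt: str,
                       save_nii: bool = True) -> List[str]:
    """Assemble a test.py call with the flags documented in the README."""
    cmd = [sys.executable, "test.py",
           "--adapt_ckpt", adapt_ckpt,
           "--data_path", data_path,
           "--ckpt", sam_ckpt]
    if save_nii:
        # Ask test.py to also save .nii prediction files.
        cmd.append("--is_savenii")
    return cmd


def expected_prediction_dir(adapt_ckpt: str) -> Path:
    """Folder named after the checkpoint minus its .pth extension.

    Assumption: the folder is created next to the checkpoint file.
    """
    return Path(adapt_ckpt).with_suffix("")
```

For a checkpoint at models/ma_sam.pth, expected_prediction_dir returns models/ma_sam.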

Running this command outputs the Dice evaluation metrics for your model. The --is_savenii argument creates a folder with the same name as your MA-SAM checkpoint file (without the .pth extension) and saves the corresponding .nii prediction files there.
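For reference, the Dice score for a class c compares the predicted region P and the ground-truth region G as Dice = 2|P ∩ G| / (|P| + |G|). The sketch below computes it per class on flattened label maps in plain Python; it is an illustration of the metric, not the repository's evaluation code, and scoring a class that is absent from both prediction and ground truth as 1.0 is an assumed convention.

```python
from typing import Dict, Iterable, Sequence


def dice_per_class(pred: Sequence[int], gt: Sequence[int],
                   classes: Iterable[int]) -> Dict[int, float]:
    """Dice = 2|P ∩ G| / (|P| + |G|) for each requested class label."""
    if len(pred) != len(gt):
        raise ValueError("prediction and ground truth must have the same number of voxels")
    scores = {}
    for c in classes:
        inter = sum(1 for p, g in zip(pred, gt) if p == c and g == c)
        denom = sum(1 for p in pred if p == c) + sum(1 for g in gt if g == c)
        # Assumed convention: a class missing from both maps counts as perfect.
        scores[c] = 1.0 if denom == 0 else 2.0 * inter / denom
    return scores
```

For example, dice_per_class([0, 1, 1, 2, 2, 0], [0, 1, 2, 2, 2, 0], [1, 2]) gives about 0.667 for class 1 and 0.8 for class 2. For 3D volumes, flatten the label arrays first; averaging the per-class scores gives a mean Dice.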

Acknowledgments

Our code is based on SAMed, FacT, and Segment Anything. We thank the authors for their great work.

Citation

If you find the code useful for your research, please cite our paper.

@article{chen2024ma,
  title={{MA-SAM}: Modality-agnostic {SAM} adaptation for {3D} medical image segmentation},
  author={Chen, Cheng and Miao, Juzheng and Wu, Dufan and Zhong, Aoxiao and Yan, Zhiling and Kim, Sekeun and Hu, Jiang and Liu, Zhengliang and Sun, Lichao and Li, Xiang and others},
  journal={Medical Image Analysis},
  volume={98},
  pages={103310},
  year={2024},
  publisher={Elsevier}
}

Note