2024 0424 ref https://github.com/bowang-lab/MedSAM/tree/LiteMedSAM
0427 set aside for now; handling administrative work first, then finish reading Mamba and come back to SAM
A lightweight version of MedSAM for fast training and inference. The model was trained in the following two stages:
Stage 1: Distill a lightweight image encoder (TinyViT) from the MedSAM image encoder (ViT) by imposing the image embedding outputs to be the same.
Stage 2: Replace the MedSAM image encoder (ViT) with TinyViT and fine-tune the whole pipeline.
The codebase is tested with: Ubuntu 20.04 | Python 3.10 | CUDA 11.8 | PyTorch 2.1.2
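A minimal sketch of the Stage 1 distillation objective described above, assuming both encoders take the same preprocessed image batch and return embeddings of identical shape; the function and variable names are illustrative, not the repository's actual API:

import torch
import torch.nn.functional as F

def distillation_step(teacher, student, images, optimizer):
    # teacher: frozen MedSAM ViT image encoder; student: lightweight TinyViT encoder
    teacher.eval()
    with torch.no_grad():
        target_embedding = teacher(images)   # e.g. (B, 256, 64, 64) image embeddings
    pred_embedding = student(images)         # student output must match the teacher's shape
    # Stage 1: force the student's image embeddings to match the teacher's
    loss = F.mse_loss(pred_embedding, target_embedding)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()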
Create a virtual environment conda create -n medsam python=3.10 -y and activate it conda activate medsam
Clone the repository: git clone -b LiteMedSAM https://github.com/bowang-lab/MedSAM/
Enter the MedSAM folder cd MedSAM and run pip install -e .
Download the LiteMedSAM checkpoint and put it in work_dir/LiteMedSAM, and download the demo data demo_scribble.zip. Then run the inference script on the demo images:
python CVPR24_LiteMedSAM_infer.py -i test_demo/imgs/ -o test_demo/segs
Build the docker image:
docker build -f Dockerfile -t litemedsam .
Note: don't forget the . at the end.
Run the docker container on the testing demo images:
docker container run -m 8G --name litemedsam --rm -v $PWD/test_demo/imgs/:/workspace/inputs/ -v $PWD/test_demo/litemedsam-seg/:/workspace/outputs/ litemedsam:latest /bin/bash -c "sh predict.sh"
Note: please run chmod -R 777 ./* if you run into a Permission denied error.
Save the docker image:
docker save litemedsam | gzip -c > litemedsam.tar.gz
Compute the evaluation metrics:
python evaluation/compute_metrics.py -s test_demo/litemedsam-seg -g test_demo/gts -csv_dir ./metrics.csv
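For reference, a minimal sketch of the Dice similarity coefficient, a standard overlap metric for segmentation; this illustrates the kind of metric reported, not the exact implementation of compute_metrics.py:

import numpy as np

def dice_coefficient(pred_mask, gt_mask):
    # Dice similarity coefficient between two binary masks
    pred = np.asarray(pred_mask, dtype=bool)
    gt = np.asarray(gt_mask, dtype=bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return 2.0 * np.logical_and(pred, gt).sum() / denom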
Download the Lite-MedSAM checkpoint and put it under the current directory.
Download the demo dataset. This tutorial assumes it is unzipped to data/FLARE22Train/.
Run the pre-processing script to convert the dataset to npz format:
python pre_CT_MR.py \
-img_path data/FLARE22Train/images \ ## path to training images
-img_name_suffix _0000.nii.gz \ ## extension of training images
-gt_path data/FLARE22Train/labels \ ## path to training labels
-gt_name_suffix .nii.gz \ ## extension of training labels
-output_path data \ ## path to save the preprocessed data
-num_workers 4 \ ## number of workers for preprocessing
-modality CT \ ## modality of the preprocessed data
-anatomy Abd \ ## anatomy of the preprocessed data
-window_level 40 \ ## window level for CT
-window_width 400 \ ## window width for CT
--save_nii ## Also save the preprocessed data in nii.gz format for visual inspection in other software
The preprocessed training cases are saved in MedSAM_train, and the last 10 cases are held out for testing and saved in MedSAM_test. For additional command line arguments, see python pre_CT_MR.py -h.
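To illustrate what the -window_level and -window_width arguments control, here is a standard CT intensity windowing step (a sketch; the exact implementation in pre_CT_MR.py may differ):

import numpy as np

def window_ct(image_hu, window_level=40, window_width=400):
    # Clip a CT volume (Hounsfield units) to the window and rescale to [0, 255]
    lower = window_level - window_width / 2.0   # 40 - 200 = -160 HU
    upper = window_level + window_width / 2.0   # 40 + 200 =  240 HU
    clipped = np.clip(image_hu, lower, upper)
    return ((clipped - lower) / (upper - lower) * 255.0).astype(np.uint8)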
Convert the training npz files to npy format for training:
python npz_to_npy.py \
-npz_dir data/MedSAM_train \ ## path to the preprocessed npz training data
-npy_dir data/npy \ ## path to save the converted npy data for training
-num_workers 4 ## number of workers for conversion in parallel
The training pipeline requires about 10 GB of GPU memory with a batch size of 4.
To train Lite-MedSAM on a single GPU, run:
python train_one_gpu.py \
-data_root data/MedSAM_train \
-pretrained_checkpoint lite_medsam.pth \
-work_dir work_dir \
-num_workers 4 \
-batch_size 4 \
-num_epochs 10
To resume interrupted training from a checkpoint, run:
python train_one_gpu.py \
-data_root data/MedSAM_train \
-resume work_dir/medsam_lite_latest.pth \
-work_dir work_dir \
-num_workers 4 \
-batch_size 4 \
-num_epochs 10
For additional command line arguments, see python train_one_gpu.py -h.
To fine-tune Lite-MedSAM on multiple GPUs, run:
python train_multi_gpus.py \
-i data/npy \ ## path to the training dataset
-task_name MedSAM-Lite-Box \
-pretrained_checkpoint lite_medsam.pth \
-work_dir ./work_dir_ddp \
-batch_size 16 \
-num_workers 8 \
-lr 0.0005 \
--data_aug \ ## use data augmentation
-world_size <WORLD_SIZE> \ ## total number of GPUs to be used
-node_rank 0 \ ## if training on a single machine, set to 0
-init_method tcp://<MASTER_ADDR>:<MASTER_PORT>
Alternatively, you can use the provided train_multi_gpus.sh script to train on multiple GPUs. To resume interrupted training from a checkpoint, add -resume <your_work_dir> to the command line arguments; for multi-GPU training, pass the work directory rather than a checkpoint path, and the script will automatically find the latest checkpoint in the work directory. For additional command line arguments, see python train_multi_gpus.py -h.
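For reference, the -world_size, -node_rank, and -init_method arguments map onto the standard torch.distributed setup; a hedged sketch of how each process is typically initialized (not the repository's exact code):

import torch
import torch.distributed as dist

def setup_ddp(rank, world_size, init_method="tcp://127.0.0.1:29500"):
    # rank: global rank of this process (node_rank * gpus_per_node + local GPU index)
    # world_size: total number of processes (GPUs) across all nodes
    # init_method: rendezvous address, e.g. tcp://<MASTER_ADDR>:<MASTER_PORT>
    dist.init_process_group(
        backend="nccl",          # NCCL backend for multi-GPU training
        init_method=init_method,
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank % torch.cuda.device_count())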
The inference script assumes the testing data have been converted to npz format.
To run inference on the 3D CT FLARE22 dataset, run:
python inference_3D.py \
-data_root data/npz/MedSAM_test/CT_Abd \ ## preprocessed npz data
-pred_save_dir ./preds/CT_Abd \
-medsam_lite_checkpoint_path work_dir/medsam_lite_latest.pth \
-num_workers 4 \
--save_overlay \ ## save segmentation overlay on the input image
-png_save_dir ./preds/CT_Abd_overlay \ ## only used when --save_overlay is set
--overwrite ## overwrite existing predictions, default continue from existing predictions
For additional command line arguments, see python inference_3D.py -h.
We also provide a script for inference on 2D images, inference_2D.py, whose usage is the same as the 3D script.
The npz format is used to store both 2D and 3D images (focusing on the ROI), along with their corresponding ground truth masks. In the case of 3D images, spacings are also included. All these data elements are compactly packed into a single npz file. This format is primarily used for distributing our training and validation datasets. Notably, for the validation dataset, bounding boxes are provided in place of ground truth label masks.
In contrast, the npy format stores each 2D image or individual slice of a 3D image along with its label mask in separate files. This format is utilized when loading data for training purposes.
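A minimal sketch for inspecting one of these npz files with NumPy; the path and the key names ('imgs', 'gts', 'spacing') are assumptions based on the description above and may differ from the actual files:

import numpy as np

data = np.load("data/MedSAM_test/CT_Abd/example_case.npz", allow_pickle=True)
print(data.files)                 # e.g. ['imgs', 'gts', 'spacing'] (assumed key names)
imgs = data["imgs"]               # image array (2D image or 3D volume around the ROI)
gts = data["gts"]                 # ground-truth masks (bounding boxes for the validation set)
print(imgs.shape, imgs.dtype, gts.shape, gts.dtype)
if "spacing" in data.files:       # 3D cases also store the voxel spacing
    print(data["spacing"])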
If you encounter difficulties loading a trained model's checkpoint for inference, we recommend first trying the extract_weights.py script located under MedSAM/utils/. This script extracts the weights from your existing checkpoint and saves them into a new checkpoint file.
To use this script, execute the following command in your terminal:
python extract_weights.py \
-from_pth <YOUR_CHECKPOINT_PATH> \
-to_pth <NEW_CHECKPOINT_PATH>
Replace <YOUR_CHECKPOINT_PATH> with the path to your saved trained model checkpoint, and <NEW_CHECKPOINT_PATH> with the desired path for the new checkpoint file. Once you have executed this command and created the new checkpoint, it should be ready for use in inference tasks.
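Conceptually, the script re-saves only the model weights from a full training checkpoint, which may also contain optimizer state, epoch counters, and so on. A hedged sketch of that idea (not the actual contents of extract_weights.py; the 'model' key is an assumption):

import torch

def extract_weights(from_pth, to_pth, key="model"):
    # Load the full checkpoint on CPU and keep only the model weights
    checkpoint = torch.load(from_pth, map_location="cpu")
    if isinstance(checkpoint, dict) and key in checkpoint:
        state_dict = checkpoint[key]   # weights stored under an assumed 'model' key
    else:
        state_dict = checkpoint        # checkpoint is already a bare state dict
    torch.save(state_dict, to_pth)

# Example usage with placeholder paths:
# extract_weights("work_dir/medsam_lite_latest.pth", "work_dir/medsam_lite_weights.pth")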
We thank the authors of MobileSAM and TinyViT for making their source code publicly available.