ShellRedia / SAM-OCTA


SAM-OCTA

Chinese README: README_zh

1.Quick Start

This project fine-tunes SAM (Segment Anything Model) with LoRA to perform segmentation tasks on OCTA images. It is built with PyTorch.

First, place a pretrained weight file in the sam_weights folder. The pretrained weights can be downloaded from the following links:

vit_h (default): https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

vit_l: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth

vit_b: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

In our tests, the RAM required for the three models (vit_h, vit_l and vit_b) was 36,248 MB, 26,154 MB and 13,467 MB, respectively. vit_h is the default option. To use a smaller model, download the corresponding weights and change the model_type setting in options.py:

...
parser.add_argument("-model_type", type=str, default="vit_h")
...
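As a rough illustration of how the -model_type option relates to the weight files and the memory figures above, the sketch below maps each model type to its checkpoint file in sam_weights and picks the largest model that fits a memory budget. The names SAM_CHECKPOINTS, MEMORY_MB, pick_model_type, checkpoint_path and build_parser are hypothetical and are not part of options.py; only the -model_type argument, the checkpoint file names and the memory numbers come from this README. The memory figures are the measurements reported above, so treat the selection as a rough guide only.

```python
import argparse
from pathlib import Path

# Checkpoint file names taken from the download links above.
# Hypothetical mapping: options.py may resolve weight paths differently.
SAM_CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}

# Memory figures reported in this README (MB), largest model first.
MEMORY_MB = {"vit_h": 36_248, "vit_l": 26_154, "vit_b": 13_467}


def pick_model_type(available_mb: int) -> str:
    """Return the largest model whose reported memory need fits the budget."""
    for model_type, needed_mb in MEMORY_MB.items():  # dicts keep insertion order
        if needed_mb <= available_mb:
            return model_type
    raise ValueError(f"{available_mb} MB is below the smallest reported requirement")


def checkpoint_path(model_type: str, weights_dir: str = "sam_weights") -> Path:
    """Build the expected path of the weight file inside the weights folder."""
    return Path(weights_dir) / SAM_CHECKPOINTS[model_type]


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Mirrors the option in options.py; `choices` is an added safeguard
    # that the original line does not have.
    parser.add_argument("-model_type", type=str, default="vit_h",
                        choices=sorted(SAM_CHECKPOINTS))
    return parser


if __name__ == "__main__":
    chosen = pick_model_type(24_000)
    args = build_parser().parse_args(["-model_type", chosen])
    print(args.model_type, checkpoint_path(args.model_type))
```

With a 24,000 MB budget, neither vit_h nor vit_l fits the reported figures, so pick_model_type returns "vit_b". argparse accepts the single-dash long name and stores it as args.model_type.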

Run train_sam_octa.py to start fine-tuning. If any packages are missing, the warning messages will tell you which ones to install. They are all commonly used Python libraries that need no additional configuration.

python train_sam_octa.py

The dataset should follow the OCTA-500 directory layout, like this:

/datasets
    /OCTA-500
        /OCTA_3M
            /GT_Artery
                10301.bmp
                10302.bmp
                ...
            /GT_Capillary
                10301.bmp
                10302.bmp
                ...
            /GT_FAZ
            ...
            /ProjectionMaps
                /OCTA(FULL)
                    10301.bmp
                    10302.bmp
                    ...
                /OCTA(ILM_OPL)
                    10301.bmp
                    10302.bmp
                    ...
                /OCTA(OPL_BM)
                    10301.bmp
                    10302.bmp
                    ...
        /OCTA_6M
            ...
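To show how this layout pairs inputs with labels, the sketch below walks one FoV folder (for example datasets/OCTA-500/OCTA_3M) and matches each ground-truth image to the projection maps with the same file name. The folder names come from the layout above; the function list_samples, its return format and the rule of pairing by identical file name are assumptions for illustration, and the repository's own data loader may work differently.

```python
from pathlib import Path


def list_samples(fov_dir, label="GT_Artery",
                 modalities=("OCTA(FULL)", "OCTA(ILM_OPL)", "OCTA(OPL_BM)")):
    """Pair projection maps with one ground-truth folder by sample ID.

    Hypothetical helper. Returns a list of
    (sample_id, {modality: image_path}, label_path), sorted by file name.
    Samples that are missing any requested modality are skipped.
    """
    fov_dir = Path(fov_dir)
    samples = []
    for gt_path in sorted((fov_dir / label).glob("*.bmp")):
        sample_id = gt_path.stem  # e.g. "10301"
        inputs = {m: fov_dir / "ProjectionMaps" / m / gt_path.name for m in modalities}
        if all(path.is_file() for path in inputs.values()):
            samples.append((sample_id, inputs, gt_path))
    return samples


if __name__ == "__main__":
    for sample_id, inputs, gt_path in list_samples("datasets/OCTA-500/OCTA_3M"):
        print(sample_id, gt_path, sorted(inputs))
```

Swapping label for GT_Capillary or GT_FAZ selects a different segmentation target with the same inputs.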

Here, the sample with ID 10301 from the 3M FoV (field of view) subset of OCTA-500 is used as an example. If you need the complete dataset, please contact the authors of the OCTA-500 dataset.

OCTA-500's related paper: https://arxiv.org/abs/2012.07261

The results and metrics will be recorded in the results folder (it is created if it does not exist).

To visualize prediction samples from the results, use display.py. Because the result folders are named by timestamp, you will need to replace the following line with your own result directory. The generated images are saved in the sample_display folder.

...
    test_dir = "results/2024-01-01-08-17-09/3M_LargeVessel_100_True/0/0000" # Your result dir
...
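Because the top-level result folders are named by timestamp (for example 2024-01-01-08-17-09), a small helper can locate the most recent run instead of editing the path by hand. This is a minimal sketch assuming that naming pattern (%Y-%m-%d-%H-%M-%S); latest_run_dir is a hypothetical helper, not part of display.py, and the subfolders below the run folder still have to match your own task.

```python
from datetime import datetime
from pathlib import Path

# Assumed from folder names such as 2024-01-01-08-17-09.
TIMESTAMP_FORMAT = "%Y-%m-%d-%H-%M-%S"


def latest_run_dir(results_root="results"):
    """Return the newest timestamp-named subfolder of results_root, or None."""
    root = Path(results_root)
    if not root.is_dir():
        return None
    runs = []
    for path in root.iterdir():
        if not path.is_dir():
            continue
        try:
            runs.append((datetime.strptime(path.name, TIMESTAMP_FORMAT), path))
        except ValueError:
            continue  # skip folders that do not follow the timestamp pattern
    if not runs:
        return None
    return max(runs, key=lambda run: run[0])[1]


if __name__ == "__main__":
    run = latest_run_dir()
    if run is not None:
        # Subfolders copied from the example path in display.py; adjust to your task.
        test_dir = str(run / "3M_LargeVessel_100_True" / "0" / "0000")
        print(test_dir)
```

Parsing the names as dates, rather than sorting them as strings, also skips unrelated folders that happen to sit inside results.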

Below are some segmentation samples with prompt points. From left to right: the input image, the ground truth and the prediction.

Local model: Artery and FAZ samples (images).

Global model: RV and Capillary samples (images).

2.Configuration

The project supports multiple segmentation tasks in two modes: global and local. The global mode performs comparably to other segmentation models, while the local mode is unique to SAM-OCTA. Both are configured in options.py; below are explanations for each option:

The following are some configurations specific to the OCTA-500 dataset:

3.Others

If you find the information useful, please cite the relevant paper: https://arxiv.org/abs/2309.11758

Pretrained Weights (Baidu Cloud Storage):

Link: https://pan.baidu.com/s/1S43QadZlhT8dL8TPbA0N6g?pwd=sifh

Password: sifh

4.Instance Prediction (Supplement)

Here, I provide additional code for vessel prediction, with explanations in text and images.

  1. Prepare an Image for Prediction. Start by preparing the image you want to predict. In the provided code, I stack the image into three channels and then place a duplicate of it side by side. The duplicate is used for manually annotating prompt points (pure green for positive points, pure red for negative points). It looks something like this:

Sample

  2. Load the Pretrained Weights. Since this is just an example, I use a fine-tuned ViT-L model, which requires less memory and computation time than the default ViT-H. The provided weights support both the global and local prediction modes. You can download them from the following link:

Link: https://pan.baidu.com/s/1iCVmPaLOWVk36YbgcQ4AOg?pwd=i54c

Password: i54c

Then run predict.py. The results are saved in an automatically created folder named prediction.

  3. In global mode, no prompt points are needed. The code automatically adds a fixed negative point at [-100, -100], which lies outside the image. The segmentation results look like this:

Sample

  4. In local mode, prompt points are placed on the vessels, for example:

Samples (two images)

The results for the provided prompt points are as follows:

Samples (two images)
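The image preparation and prompting steps above can be sketched in NumPy as follows. This is a minimal illustration, not the code in predict.py, and the function and constant names are hypothetical. It assumes a single-channel uint8 image and RGB channel order (images loaded with OpenCV are BGR, so the two colors would swap). Because a grayscale pixel has three equal channels, it can never equal pure green (0, 255, 0) or pure red (255, 0, 0), so only annotation marks are picked up. Points are returned as (x, y) with label 1 for positive and 0 for negative, the convention SAM uses for point prompts. When nothing is marked, the sketch falls back to the single negative point at [-100, -100] described for global mode.

```python
import numpy as np

GREEN = (0, 255, 0)  # positive prompt (RGB order assumed)
RED = (255, 0, 0)    # negative prompt
GLOBAL_NEGATIVE_POINT = (-100, -100)  # fixed point outside the image, as in global mode


def make_annotation_canvas(gray: np.ndarray) -> np.ndarray:
    """Stack a grayscale image into 3 channels and place a copy to its right."""
    rgb = np.stack([gray] * 3, axis=-1)                # (H, W) -> (H, W, 3)
    return np.concatenate([rgb, rgb.copy()], axis=1)   # (H, 2W, 3)


def extract_prompts(canvas: np.ndarray):
    """Read pure-green / pure-red pixels in the right half as point prompts.

    Returns (coords, labels): coords is an (N, 2) array of (x, y) positions
    in the original image frame; labels is 1 for positive, 0 for negative.
    Every matching pixel becomes one point, so mark each prompt with a
    single pixel.
    """
    width = canvas.shape[1] // 2
    annotated = canvas[:, width:]  # the right half holds the marks
    coords, labels = [], []
    for color, label in ((GREEN, 1), (RED, 0)):
        ys, xs = np.nonzero(np.all(annotated == color, axis=-1))
        coords.extend(zip(xs.tolist(), ys.tolist()))
        labels.extend([label] * len(xs))
    if not coords:  # no marks: fall back to the global-mode negative point
        coords, labels = [GLOBAL_NEGATIVE_POINT], [0]
    return np.array(coords, dtype=np.int64), np.array(labels, dtype=np.int64)


if __name__ == "__main__":
    gray = np.full((4, 4), 128, dtype=np.uint8)
    canvas = make_annotation_canvas(gray)
    canvas[1, 4 + 2] = GREEN  # positive point at x=2, y=1
    canvas[3, 4 + 0] = RED    # negative point at x=0, y=3
    coords, labels = extract_prompts(canvas)
    print(coords.tolist(), labels.tolist())
```

Annotating a copy instead of the input itself keeps the image passed to the model free of colored marks, while the shared pixel grid lets the mark positions map directly back to the original coordinates.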