Generative Medical Segmentation

This is the official repository of Generative Medical Segmentation (GMS).

Paper | Weights

Updates

Introduction

We introduce Generative Medical Segmentation (GMS), a novel approach that leverages a generative model for image segmentation. Concretely, GMS employs a robust pre-trained Variational Autoencoder (VAE) to derive latent representations of both images and masks, followed by a mapping model that learns the transition from image to mask in the latent space. A precise segmentation mask is then generated in the image space by the pre-trained VAE decoder. This design gives GMS fewer learnable parameters, resulting in a reduced computational burden and enhanced generalization capability. Our extensive experimental analysis across five public datasets in different medical imaging domains demonstrates that GMS outperforms existing discriminative segmentation models and has remarkable domain generalization ability.

Overview of GMS

We use the pre-trained Stable Diffusion (SD) Variational Auto-Encoder to obtain the latent representation of input images and to reconstruct the predicted segmentation mask from the latent space. For the reconstruction capability of the SD VAE, please check the supplementary material. The latent mapping model is built purely on CNNs and contains no down-sampling layers, which prevents spatial information loss.

overview
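
The overall data flow can be summarized in a few lines of PyTorch. This is a minimal sketch under stated assumptions: it presumes a frozen SD VAE object exposing encode/decode methods, and the names LatentMappingModel and gms_forward are illustrative, not the repository's actual API.

import torch
import torch.nn as nn

class LatentMappingModel(nn.Module):
    """Illustrative CNN mapping image latents to mask latents.
    Stride-1 convolutions only: no down-sampling layers, so no
    spatial information is lost in the latent space."""
    def __init__(self, channels=4, hidden=64, depth=4):  # SD VAE latents have 4 channels
        super().__init__()
        layers = [nn.Conv2d(channels, hidden, 3, padding=1), nn.ReLU()]
        for _ in range(depth - 2):
            layers += [nn.Conv2d(hidden, hidden, 3, padding=1), nn.ReLU()]
        layers += [nn.Conv2d(hidden, channels, 3, padding=1)]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        return self.net(z)

def gms_forward(vae, mapper, image):
    """Encode image -> map latent -> decode predicted mask."""
    with torch.no_grad():           # the SD VAE stays frozen
        z_img = vae.encode(image)   # latent representation of the image
    z_mask = mapper(z_img)          # learned image-to-mask latent mapping
    return vae.decode(z_mask)       # predicted mask in image space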

Getting Started

Environment Setup

We provide a conda env file that contains all the required dependencies. You can use it to create and activate a conda environment.

conda env create -f environment.yaml
conda activate GMS

Alternatively, use a Python virtual environment:

python3 -m venv GMS
source GMS/bin/activate
pip install -r requirements.txt

Prepare datasets

We evaluate GMS on five public datasets: BUS, BUSI, GlaS, HAM10000 and Kvasir-Instrument. The structure of the Dataset folder should be as follows:

Dataset/
├── bus
│   ├── bus_train_test_names.pkl
│   ├── images
│   │   ├── 000001.png
│   │   ...
│   │   └── 000310.png
│   └── masks
│       ├── 000001.png
│       ...
│       └── 000310.png
├── busi
│   ├── busi_train_test_names.pkl
│   ├── images
│   │   ├── benign_100.png
│   │   ...
│   │   └── normal_9.png
│   └── masks
│       ├── benign_100.png
│       ...
│       └── normal_9.png
├── glas
│   ├── glas_train_test_names.pkl
│   ├── images
│   │   ├── testA_10.png
│   │   ...
│   │   └── train_9.png
│   └── masks
│       ├── testA_10.png
│       ...
│       └── train_9.png
├── ham10000
│   ├── ham10000_train_test_names.pkl
│   ├── images
│   │   ├── ISIC_0024306.png
│   │   ...
│   │   └── ISIC_0034320.png
│   └── masks
│       ├── ISIC_0024306.png
│       ...
│       └── ISIC_0034320.png
├── kvasir-instrument
│   ├── kvasir_train_test_names.pkl
│   ├── images
│   │   ├── ckcu8ty6z00003b5yzfaezbs5.png
│   │   ...
│   │   └── ckd4kgex0000h3b5yhpwjd11l.png
│   └── masks
│       ├── ckcu8ty6z00003b5yzfaezbs5.png
│       ...
│       └── ckd4kgex0000h3b5yhpwjd11l.png
└── show_pkl.py

We provide the preprocessed BUSI and Kvasir-Instrument datasets via this link; please download the dataset file and unzip it into the Dataset folder. For the other datasets, please download them from the dataset websites and organize them in the same structure. The .pkl file stores the train/test split for each dataset; you can run show_pkl.py to display the content of each .pkl file.
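
If you prefer not to run the script, a split file can also be inspected directly in a Python shell. A minimal sketch; the exact structure of the loaded object (e.g. a dict of train/test name lists) is an assumption to verify against show_pkl.py.

import pickle

# Path shown for BUSI; swap in any of the five split files.
with open("Dataset/busi/busi_train_test_names.pkl", "rb") as f:
    split = pickle.load(f)

print(type(split))
print(split)  # e.g. {'train': [...], 'test': [...]} -- structure assumed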

Download the pre-trained SD VAE weight

We provide the pre-trained SD VAE weights via this link; please download the file and place it in the SD-VAE-weights folder.
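
The training and inference scripts load this file for you. If you want to load it manually, something like the following should work, assuming the downloaded file is a standard PyTorch state dict; the path and the way the VAE module is constructed are placeholders, not the repo's exact loading code.

import torch

def load_frozen_vae(vae, ckpt_path="SD-VAE-weights/sd_vae.ckpt"):
    """Load SD VAE weights and freeze them (GMS never trains the VAE)."""
    state = torch.load(ckpt_path, map_location="cpu")
    vae.load_state_dict(state)
    vae.eval()
    for p in vae.parameters():
        p.requires_grad = False   # only the latent mapping model is trained
    return vae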

Model Inference

We provide trained model weights for all five datasets in the ckpt folder. Once all datasets are preprocessed, please run the following command for model inference:

sh valid.sh

The DSC, IoU, and predicted masks will be saved automatically.
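
For reference, the two reported metrics follow the standard definitions below. This is a sketch on binary numpy masks; the repository's evaluation code may differ in thresholding or smoothing details.

import numpy as np

def dice_score(pred, gt, eps=1e-7):
    """DSC = 2 * |P intersect G| / (|P| + |G|)."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)

def iou_score(pred, gt, eps=1e-7):
    """IoU = |P intersect G| / |P union G|."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    return (inter + eps) / (union + eps)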

Model training

Please run the following command for model training:

sh train.sh

To change hyper-parameters (batch size, learning rate, number of training epochs, etc.), please refer to the dataset's training yaml file (e.g. the BUSI training yaml). We train GMS on an NVIDIA A100 40GB GPU with the batch size set to 8. If you encounter out-of-memory (OOM) errors, try decreasing the batch size.
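
Conceptually, each training step only updates the latent mapping model, since the VAE is frozen. A minimal sketch; the vae, mapper, and loss_fn objects and the choice of optimizer are illustrative, not the repository's exact setup.

import torch

def train_step(vae, mapper, loss_fn, optimizer, image, mask):
    """One optimization step; only `mapper` parameters receive gradients."""
    with torch.no_grad():
        z_img = vae.encode(image)            # frozen encoder, no gradients
    pred_mask = vae.decode(mapper(z_img))    # gradients flow through the
                                             # frozen decoder into `mapper`
    loss = loss_fn(pred_mask, mask)          # segmentation loss (assumed)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()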

Trainable Parameters

Since we freeze the SD VAE and only train the latent mapping model, GMS has far fewer trainable parameters than other medical image segmentation models. Note that EGE-UNet is a well-designed lightweight model for medical image segmentation.

trainable_parameters
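
You can verify this for any model by counting the parameters with requires_grad=True; with the VAE frozen, only the mapping model contributes (the mapper and vae names below are illustrative):

def count_trainable(model):
    """Number of parameters the optimizer will actually update."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# print(count_trainable(mapper))  # small CNN mapper only
# print(count_trainable(vae))     # 0 after freezing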

Quantitative Segmentation Results

Quantitative performance on different datasets compared to other models. Best and second-best performances are shown in bold and underlined, respectively. † indicates a model with fewer trainable parameters than GMS.

quantitative_results

Visualization Results

qualitative_results

Cross-domain Segmentation Results

We were surprised to find that the intrinsic domain generalization ability of GMS is much stronger than that of other segmentation models, even better than some methods (MixStyle and DSU) that were specifically designed for the domain generalization problem.

cross_domain_quantitative_results

Citation

If you use this code for your research, please consider citing our paper.

@article{huo2024generative,
  title={Generative Medical Segmentation},
  author={Huo, Jiayu and Ouyang, Xi and Ourselin, S{\'e}bastien and Sparks, Rachel},
  journal={arXiv preprint arXiv:2403.18198},
  year={2024}
}

Acknowledgments

Thanks to the following code repositories: Stable Diffusion and GSS