icon-lab / SelfRDB

Official PyTorch implementation of SelfRDB, a diffusion bridge model for multi-modal medical image synthesis
MIT License
deep-learning diffusion-bridge diffusion-models image-synthesis image-to-image-translation medical-imaging neural-networks python pytorch schrodinger-bridge

SelfRDB
Self-Consistent Recursive Diffusion Bridge for Medical Image Translation

Fuat Arslan (1,2) · Bilal Kabas (1,2) · Onat Dalmaz (3) · Muzaffer Ozbey (4) · Tolga Çukur (1,2)
(1) Bilkent University   (2) UMRAM   (3) Stanford University   (4) University of Illinois Urbana-Champaign

[arXiv]

Official PyTorch implementation of SelfRDB, a novel diffusion bridge model for multi-modal medical image synthesis. SelfRDB combines a forward process with a soft prior and a reverse process with self-consistent recursion. The soft prior is a noise schedule whose variance increases monotonically towards the end-point; it improves generalization and facilitates information transfer between the two modalities. To further improve sampling accuracy, each reverse step uses a recursive procedure in which the network repeatedly generates a transient estimate of the target image until it converges onto a self-consistent solution.
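To make the soft-prior idea concrete, here is a toy NumPy sketch of a bridge forward process. The schedule in the hypothetical `toy_bridge_schedule` is not the schedule used in SelfRDB: the mean interpolates linearly from the target image x0 at t = 0 to the source image y at t = T, and the noise variance grows monotonically to a nonzero value at the end-point, so the end-point is a noisy version of y rather than y itself.

```python
import numpy as np


def toy_bridge_schedule(num_steps, sigma_max=0.5, power=2.0):
    """Hypothetical schedule (NOT the paper's): source weight s_t = t/T and a
    noise std whose variance sigma_max**2 * (t/T)**power rises monotonically
    to a nonzero value at the end-point t = T (the "soft prior")."""
    t = np.arange(num_steps + 1) / num_steps   # normalized time t/T in [0, 1]
    s = t                                      # weight on the source image y
    sigma = sigma_max * t ** (power / 2)       # std; zero at t = 0
    return s, sigma


def sample_forward(x0, y, step, s, sigma, rng):
    """Draw x_t ~ N((1 - s_t) * x0 + s_t * y, sigma_t**2 * I)."""
    mean = (1.0 - s[step]) * x0 + s[step] * y
    return mean + sigma[step] * rng.standard_normal(x0.shape)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x0 = np.zeros((4, 4))   # stand-in target image in [0, 1]
    y = np.ones((4, 4))     # stand-in source image in [0, 1]
    s, sigma = toy_bridge_schedule(10)
    x_end = sample_forward(x0, y, 10, s, sigma, rng)  # noisy around y
    print(x_end.mean(), sigma[-1] ** 2)
```

Because the end-point variance is nonzero, a sample at t = T is centered on y but not pinned to it; a hard-prior bridge would instead force the variance to zero at the end-point.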

[Figure: SelfRDB architecture]
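The self-consistent recursion in the reverse process can be read as a fixed-point iteration: the network's current estimate of the target image is fed back as input until successive estimates stop changing. The sketch below shows only that loop. `denoiser` is a hypothetical stand-in for the network (here a toy contraction whose fixed point is x_t), and the tolerance-based stopping rule and iteration cap are assumptions, not the repository's settings.

```python
import numpy as np


def self_consistent_estimate(denoiser, x_t, init, tol=1e-4, max_iters=50):
    """Recursively refine a target-image estimate until it is self-consistent.

    denoiser(x_t, estimate) -> new estimate  (hypothetical network stand-in)
    Stops when the max absolute change between successive estimates falls
    below `tol`, or after `max_iters` iterations.
    Returns (estimate, number_of_iterations_used).
    """
    estimate = init
    for iteration in range(1, max_iters + 1):
        new_estimate = denoiser(x_t, estimate)
        change = np.max(np.abs(new_estimate - estimate))
        estimate = new_estimate
        if change < tol:
            return estimate, iteration
    return estimate, max_iters


if __name__ == "__main__":
    # Toy contraction: its fixed point satisfies e = 0.5 * x_t + 0.5 * e, i.e. e = x_t.
    toy = lambda x_t, est: 0.5 * x_t + 0.5 * est
    est, n = self_consistent_estimate(toy, np.ones((4, 4)), np.zeros((4, 4)))
    print(n, float(est.mean()))
```

In a full sampler this loop would run once per reverse step, and its converged estimate would then be used to draw the next, less noisy sample.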

⚙️ Installation

This repository has been developed and tested with CUDA 11.7 and Python 3.8. The commands below create a conda environment with the required packages. Make sure conda is installed first.

conda env create --file requirements.yaml
conda activate selfrdb
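The versions above are the ones the code was tested with, not stated minimums. After activating the environment, a quick check such as the hypothetical `check_environment` helper below (not part of the repository) reports whether the interpreter and the PyTorch build match those tested versions and whether a GPU is visible. It imports PyTorch only if it is installed, so it also runs outside the conda environment.

```python
import importlib.util
import sys

TESTED_PYTHON = (3, 8)   # Python version stated in this README
TESTED_CUDA = "11.7"     # CUDA version stated in this README


def check_environment():
    """Compare the running environment with the tested versions.

    Never raises if PyTorch is missing; the CUDA fields stay None/False instead.
    """
    report = {
        "python": ".".join(str(part) for part in sys.version_info[:3]),
        "python_matches_tested": sys.version_info[:2] == TESTED_PYTHON,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "torch_cuda": None,           # CUDA version PyTorch was built with
        "cuda_matches_tested": False,
        "cuda_available": False,      # whether a GPU is usable right now
    }
    if report["torch_installed"]:
        import torch

        report["torch_cuda"] = torch.version.cuda
        report["cuda_matches_tested"] = torch.version.cuda == TESTED_CUDA
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```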

🗂️ Prepare dataset

The default dataset class, NumpyDataset, expects the folder structure below. Each modality (T1, T2, etc.) has its own folder, and each split (train, val, test) is a subfolder containing 2D images: slice_0.npy, slice_1.npy, ... To use a custom dataset class, implement it in dataset.py by inheriting from the BaseDataset class, and set dataset_class to your implementation.

Images should be scaled to pixel values in the range [0, 1].

<dataset>/
├── <modality_a>/
│   ├── train/
│   │   ├── slice_0.npy
│   │   ├── slice_1.npy
│   │   └── ...
│   ├── test/
│   │   ├── slice_0.npy
│   │   └── ...
│   └── val/
│       ├── slice_0.npy
│       └── ...
├── <modality_b>/
│   ├── train/
│   ├── test/
│   └── val/
└── ...
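A minimal sketch of preparing data in this layout with NumPy. The helpers `to_unit_range`, `save_split`, and `load_pair` are hypothetical, not part of the repository. Per-slice min-max scaling is only one way to meet the [0, 1] requirement (per-volume scaling is another), and `load_pair` assumes that slices sharing an index in different modality folders are co-registered pairs; how NumpyDataset actually matches pairs is an assumption here.

```python
from pathlib import Path

import numpy as np


def to_unit_range(image):
    """Min-max scale one 2D image to [0, 1] (one possible normalization)."""
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    if hi == lo:                       # constant image: avoid division by zero
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


def save_split(dataset_dir, modality, split, slices):
    """Write 2D slices as <dataset>/<modality>/<split>/slice_<i>.npy."""
    out_dir = Path(dataset_dir) / modality / split
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, image in enumerate(slices):
        np.save(out_dir / f"slice_{i}.npy", to_unit_range(image))
    return out_dir


def load_pair(dataset_dir, source, target, split, index):
    """Load the source and target slices that share the same index."""
    root = Path(dataset_dir)
    name = f"slice_{index}.npy"
    return (np.load(root / source / split / name),
            np.load(root / target / split / name))
```

Call `save_split` once per modality and split (for example, T1 and T2 for train, val, and test), keeping the slice order identical across modalities so that the indices line up.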

🏃 Training

Run the following command to start or resume training. Model checkpoints are saved under the logs/$EXP_NAME/version_x/checkpoints directory, and sample validation images under logs/$EXP_NAME/version_x/val_samples. The script supports both single- and multi-GPU training and runs on a single GPU by default. To enable multi-GPU training, set the --trainer.devices argument to a list of devices, e.g. 0,1,2,3.

python main.py fit \
    --config config.yaml \
    --trainer.logger.name $EXP_NAME \
    --data.dataset_dir $DATA_DIR \
    --data.source_modality $SOURCE \
    --data.target_modality $TARGET \
    --data.train_batch_size $BS_TRAIN \
    --data.val_batch_size $BS_VAL \
    [--trainer.max_epochs $N_EPOCHS] \
    [--ckpt_path $CKPT_PATH] \
    [--trainer.devices $DEVICES]

Argument descriptions

| Argument | Description |
|---|---|
| --config | Config file path. |
| --trainer.logger.name | Experiment name. |
| --data.dataset_dir | Dataset directory. |
| --data.source_modality | Source modality, e.g. 'T1', 'T2', 'PD'. Must match the folder name for that modality. |
| --data.target_modality | Target modality, e.g. 'T1', 'T2', 'PD'. Must match the folder name for that modality. |
| --data.train_batch_size | Training set batch size. |
| --data.val_batch_size | Validation set batch size. |
| --trainer.max_epochs | [Optional] Number of training epochs (default: 50). |
| --ckpt_path | [Optional] Model checkpoint path to resume training from. |
| --trainer.devices | [Optional] Device or list of devices. For multi-GPU training, set to the list of device ids, e.g. 0,1,2,3 (default: [0]). |
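To script several runs (for example, one per task in the model zoo), the training command can be assembled programmatically. The sketch below uses a hypothetical helper, `build_fit_command`, that is not part of the repository. It only emits the flags shown in the command above, appends --ckpt_path and --trainer.devices only when a value is given, and passes any other flags through `extra` unchanged. The experiment names and the /data/IXI path in the usage block are placeholders.

```python
import shlex


def build_fit_command(exp_name, data_dir, source, target, train_bs, val_bs,
                      config="config.yaml", ckpt_path=None, devices=None,
                      extra=()):
    """Return the argument list for `python main.py fit` (hypothetical helper)."""
    cmd = [
        "python", "main.py", "fit",
        "--config", config,
        "--trainer.logger.name", exp_name,
        "--data.dataset_dir", str(data_dir),
        "--data.source_modality", source,   # must match a modality folder name
        "--data.target_modality", target,
        "--data.train_batch_size", str(train_bs),
        "--data.val_batch_size", str(val_bs),
    ]
    if ckpt_path is not None:               # resume from a checkpoint
        cmd += ["--ckpt_path", str(ckpt_path)]
    if devices is not None:
        # Comma-separated device ids, e.g. [0, 1, 2, 3] -> "0,1,2,3"
        cmd += ["--trainer.devices", ",".join(str(d) for d in devices)]
    cmd += list(extra)                      # any other flags, unchanged
    return cmd


if __name__ == "__main__":
    # One run per task; printed with shell quoting so it can be copy-pasted.
    for src, tgt in [("T2", "T1"), ("T1", "T2")]:
        cmd = build_fit_command(f"ixi_{src}_{tgt}", "/data/IXI", src, tgt, 1, 1)
        print(shlex.join(cmd))
```

Returning a list rather than a single string avoids quoting problems; to launch a run directly, pass the list to `subprocess.run(cmd, check=True)` from the repository root.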

🧪 Testing

Run the following command to start testing. Predicted images are saved under the logs/$EXP_NAME/version_x/test_samples directory. By default, the script runs on a single GPU. To enable multi-GPU testing, set the --trainer.devices argument to a list of devices, e.g. 0,1,2,3.

python main.py test \
    --config config.yaml \
    --data.dataset_dir $DATA_DIR \
    --data.source_modality $SOURCE \
    --data.target_modality $TARGET \
    --data.test_batch_size $BS_TEST \
    --ckpt_path $CKPT_PATH \
    [--trainer.devices $DEVICES]

Argument descriptions

Some arguments are common to both training and testing and are not listed here. For details on those arguments, please refer to the training section.

| Argument | Description |
|---|---|
| --data.test_batch_size | Test set batch size. |
| --ckpt_path | Model checkpoint path. |
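Testing needs an explicit checkpoint path. When several runs exist, a small helper can pick the newest checkpoint under the logs/$EXP_NAME/version_x/checkpoints layout described in the training section. The sketch below is hypothetical (not part of the repository) and assumes checkpoint files use the .ckpt extension that PyTorch Lightning uses by default; check the actual file names in your checkpoints directory.

```python
from pathlib import Path


def latest_checkpoint(exp_name, log_root="logs"):
    """Return the newest *.ckpt under <log_root>/<exp_name>/version_*/checkpoints.

    Returns None when no checkpoint exists. The .ckpt extension is an assumption.
    """
    pattern = f"{exp_name}/version_*/checkpoints/*.ckpt"
    candidates = list(Path(log_root).glob(pattern))
    if not candidates:
        return None
    # Newest by modification time across all version_x folders
    return max(candidates, key=lambda path: path.stat().st_mtime)
```

Pass `str(latest_checkpoint(exp_name))` as $CKPT_PATH. The newest file is not necessarily the best-performing one, so if your run keeps several checkpoints, point --ckpt_path at the one you want explicitly.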

🦁 Model Zoo

Refer to the testing section above to perform inference with the checkpoints. PSNR (dB) and SSIM (%) are listed as mean ± std across the test set.

| Dataset | Task | PSNR (dB) | SSIM (%) | Checkpoint |
|---|---|---|---|---|
| IXI | T2→T1 | 31.63 ± 1.53 | 95.64 ± 1.12 | Link |
| IXI | T1→T2 | 31.28 ± 1.56 | 95.03 ± 1.27 | Link |
| IXI | PD→T1 | 31.23 ± 1.22 | 95.64 ± 0.99 | Link |
| IXI | T1→PD | 32.17 ± 1.57 | 95.15 ± 0.99 | Link |
| BRATS | T2→T1 | 28.85 ± 1.48 | 93.70 ± 1.87 | Link |
| BRATS | T1→T2 | 27.58 ± 1.88 | 92.99 ± 2.44 | Link |
| BRATS | FLAIR→T2 | 26.85 ± 1.75 | 91.66 ± 2.72 | Link |
| BRATS | T2→FLAIR | 27.98 ± 1.80 | 90.01 ± 2.70 | Link |
| CT | T2→CT | 29.18 ± 2.18 | 93.28 ± 1.99 | Link |
| CT | T1→CT | 27.55 ± 3.32 | 92.29 ± 6.32 | Link |
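To score your own predictions in the same mean ± std form as the table, per-image PSNR can be computed and then summarized across the test set. The sketch below is not the repository's evaluation code. It assumes images in [0, 1] (so the peak value is 1), computes PSNR over the whole image without masking, and reports the population standard deviation; the exact evaluation protocol behind the table may differ. SSIM is more involved; scikit-image provides `skimage.metrics.structural_similarity`, which accepts a `data_range` argument.

```python
import numpy as np


def psnr(reference, estimate, data_range=1.0):
    """PSNR in dB: 10 * log10(data_range**2 / MSE), for images in [0, data_range]."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    mse = np.mean((reference - estimate) ** 2)
    if mse == 0:                 # identical images
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))


def summarize(values):
    """Mean and population standard deviation across the test set."""
    values = np.asarray(values, dtype=np.float64)
    return float(values.mean()), float(values.std())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    targets = [rng.random((8, 8)) for _ in range(5)]
    preds = [np.clip(t + 0.02 * rng.standard_normal(t.shape), 0, 1) for t in targets]
    mean, std = summarize([psnr(t, p) for t, p in zip(targets, preds)])
    print(f"PSNR: {mean:.2f} ± {std:.2f} dB")
```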

✒️ Citation

You are encouraged to modify/distribute this code. However, please acknowledge this code and cite the paper appropriately.

@article{arslan2024selfconsistent,
  title={Self-Consistent Recursive Diffusion Bridge for Medical Image Translation}, 
  author={Fuat Arslan and Bilal Kabas and Onat Dalmaz and Muzaffer Ozbey and Tolga Çukur},
  year={2024},
  journal={arXiv:2405.06789}
}

Copyright © 2024, ICON Lab.