pfriedri / wdm-3d

PyTorch implementation for "WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis" (DGM4MICCAI 2024)
https://pfriedri.github.io/wdm-3d-io
MIT License
3d diffusion-models image-generation wavelet-transform

WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis


This is the official PyTorch implementation of the paper WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis by Paul Friedrich, Julia Wolleb, Florentin Bieder, Alicia Durrer and Philippe C. Cattin.

If you find our work useful, please consider starring this repository and citing our paper:

@article{friedrich2024wdm,
         title={WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis},
         author={Paul Friedrich and Julia Wolleb and Florentin Bieder and Alicia Durrer and Philippe C. Cattin},
         year={2024},
         journal={arXiv preprint arXiv:2402.19043}}

Paper Abstract

Due to the three-dimensional nature of CT- or MR-scans, generative modeling of medical images is a particularly challenging task. Existing approaches mostly apply patch-wise, slice-wise, or cascaded generation techniques to fit the high-dimensional data into the limited GPU memory. However, these approaches may introduce artifacts and potentially restrict the model's applicability for certain downstream tasks. This work presents WDM, a wavelet-based medical image synthesis framework that applies a diffusion model on wavelet decomposed images. The presented approach is a simple yet effective way of scaling diffusion models to high resolutions and can be trained on a single 40 GB GPU. Experimental results on BraTS and LIDC-IDRI unconditional image generation at a resolution of 128 x 128 x 128 show state-of-the-art image fidelity (FID) and sample diversity (MS-SSIM) scores compared to GANs, Diffusion Models, and Latent Diffusion Models. Our proposed method is the only one capable of generating high-quality images at a resolution of 256 x 256 x 256.
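The idea of applying the diffusion model to wavelet-decomposed images can be illustrated with a single-level 3D Haar transform. A volume of shape D x H x W becomes eight subbands of shape D/2 x H/2 x W/2, so a 128 x 128 x 128 scan turns into an 8-channel 64 x 64 x 64 representation holding the same number of values, and the transform is exactly invertible. The NumPy sketch below is illustrative only. The function names are hypothetical, and the repository's own wavelet implementation (filter choice, subband ordering, normalization) may differ.

```python
import math

import numpy as np

INV_SQRT2 = 1.0 / math.sqrt(2.0)  # normalization of the orthonormal Haar filters


def _every_other(axis: int, start: int) -> tuple:
    """Index tuple selecting every second element along `axis` of a 3D array."""
    index = [slice(None)] * 3
    index[axis] = slice(start, None, 2)
    return tuple(index)


def haar_dwt_3d(volume: np.ndarray) -> np.ndarray:
    """Single-level 3D Haar DWT: (D, H, W) -> (8, D/2, H/2, W/2).

    Subband 0 is the low-pass (LLL) approximation; subbands 1-7 hold the
    detail coefficients. All side lengths must be even.
    """
    if volume.ndim != 3 or any(s % 2 for s in volume.shape):
        raise ValueError("expected a 3D volume with even side lengths")
    bands = [volume]
    for axis in range(3):
        split = []
        for band in bands:
            even = band[_every_other(axis, 0)]
            odd = band[_every_other(axis, 1)]
            split.append((even + odd) * INV_SQRT2)  # low-pass: scaled average
            split.append((even - odd) * INV_SQRT2)  # high-pass: scaled difference
        bands = split
    return np.stack(bands)


def haar_idwt_3d(coeffs: np.ndarray) -> np.ndarray:
    """Inverse of haar_dwt_3d: (8, D/2, H/2, W/2) -> (D, H, W)."""
    bands = list(coeffs)
    for axis in reversed(range(3)):
        merged = []
        # Consecutive (low, high) pairs were produced from the same parent band.
        for low, high in zip(bands[0::2], bands[1::2]):
            shape = list(low.shape)
            shape[axis] *= 2
            parent = np.empty(shape, dtype=low.dtype)
            parent[_every_other(axis, 0)] = (low + high) * INV_SQRT2
            parent[_every_other(axis, 1)] = (low - high) * INV_SQRT2
            merged.append(parent)
        bands = merged
    return bands[0]


volume = np.random.default_rng(0).standard_normal((128, 128, 128))
coeffs = haar_dwt_3d(volume)
print(coeffs.shape)                               # (8, 64, 64, 64)
print(np.allclose(haar_idwt_3d(coeffs), volume))  # True
```

Because the Haar filters used here are orthonormal, the transform preserves the total energy of the volume, and the low-pass subband is a scaled, half-resolution copy of the input.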

Dependencies

We recommend using a conda environment to install the required dependencies. You can create and activate such an environment, called wdm, by running the following commands (mamba is a drop-in replacement for conda; the equivalent conda commands work as well):

mamba env create -f environment.yml
mamba activate wdm

Training & Sampling

To train a new model or sample from an already trained one, adapt and run the script run.sh. All hyperparameters relevant for reproducing our results are set automatically when the correct MODEL is selected in the general settings. To execute the script, use the following command:

bash run.sh

Supported settings (set in run.sh file):

MODE: 'training', 'sampling'

MODEL: 'ours_unet_128', 'ours_unet_256', 'ours_wnet_128', 'ours_wnet_256'

DATASET: 'brats', 'lidc-idri'
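The MODEL names appear to encode the network variant and the cubic output resolution (128 or 256, matching the resolutions mentioned in the abstract). The hypothetical helper below is not part of the repository; it only sketches how the supported run.sh values could be validated and decoded before launching a job, under the assumption that the numeric suffix denotes the resolution.

```python
from dataclasses import dataclass

# Values listed as supported for the run.sh settings.
MODES = ("training", "sampling")
MODELS = ("ours_unet_128", "ours_unet_256", "ours_wnet_128", "ours_wnet_256")
DATASETS = ("brats", "lidc-idri")


@dataclass(frozen=True)
class RunSettings:
    mode: str
    network: str     # "unet" or "wnet", decoded from MODEL
    resolution: int  # assumed side length of the cubic volume (128 or 256)
    dataset: str


def parse_settings(mode: str, model: str, dataset: str) -> RunSettings:
    """Hypothetical helper: validate run.sh-style settings and decode MODEL."""
    for name, value, allowed in (("MODE", mode, MODES),
                                 ("MODEL", model, MODELS),
                                 ("DATASET", dataset, DATASETS)):
        if value not in allowed:
            raise ValueError(f"unsupported {name} {value!r}; expected one of {allowed}")
    # "ours_wnet_256" -> ("ours", "wnet", "256")
    _, network, size = model.split("_")
    return RunSettings(mode, network, int(size), dataset)


print(parse_settings("sampling", "ours_wnet_256", "brats"))
# RunSettings(mode='sampling', network='wnet', resolution=256, dataset='brats')
```

Failing early on a misspelled value (for example 'train' instead of 'training') is cheaper than discovering it after the job has been scheduled on a GPU.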

Pretrained Models

We released pretrained models on HuggingFace.

Currently available models:

Data

To ensure good reproducibility, we trained and evaluated our network on two publicly available datasets, BraTS and LIDC-IDRI.

The provided code works for the following data structure (you might need to adapt the DATA_DIR variable in run.sh):

data
└───BRATS
    └───BraTS-GLI-00000-000
        └───BraTS-GLI-00000-000-seg.nii.gz
        └───BraTS-GLI-00000-000-t1c.nii.gz
        └───BraTS-GLI-00000-000-t1n.nii.gz
        └───BraTS-GLI-00000-000-t2f.nii.gz
        └───BraTS-GLI-00000-000-t2w.nii.gz
    └───BraTS-GLI-00001-000
    └───BraTS-GLI-00002-000
    ...
└───LIDC-IDRI
    └───LIDC-IDRI-0001
        └───preprocessed.nii.gz
    └───LIDC-IDRI-0002
    └───LIDC-IDRI-0003
    ...
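Before training, it can help to verify that DATA_DIR matches the layout above. The hypothetical check below is not part of the repository. It uses only pathlib and looks for exactly the files shown in the tree: five NIfTI files per BraTS case, named after the case folder, and one preprocessed.nii.gz per LIDC-IDRI case. The repository's data loaders may be more or less strict than this.

```python
from __future__ import annotations

from pathlib import Path

# BraTS file suffixes as they appear in the directory tree above.
BRATS_SUFFIXES = ("seg", "t1c", "t1n", "t2f", "t2w")


def find_incomplete_cases(data_dir: str) -> dict[str, list[str]]:
    """Hypothetical sanity check: map each incomplete case to its missing files."""
    root = Path(data_dir)
    missing: dict[str, list[str]] = {}

    brats = root / "BRATS"
    if brats.is_dir():
        for case in sorted(p for p in brats.iterdir() if p.is_dir()):
            expected = [f"{case.name}-{s}.nii.gz" for s in BRATS_SUFFIXES]
            absent = [name for name in expected if not (case / name).is_file()]
            if absent:
                missing[f"BRATS/{case.name}"] = absent

    lidc = root / "LIDC-IDRI"
    if lidc.is_dir():
        for case in sorted(p for p in lidc.iterdir() if p.is_dir()):
            if not (case / "preprocessed.nii.gz").is_file():
                missing[f"LIDC-IDRI/{case.name}"] = ["preprocessed.nii.gz"]

    return missing
```

An empty dictionary means every case folder contains the files shown in the tree.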

We provide a script for preprocessing LIDC-IDRI. Run the following command, replacing DICOM_PATH with the path to the downloaded DICOM files and NIFTI_PATH with the directory in which the processed NIfTI files should be stored:

python utils/preproc_lidc-idri.py --dicom_dir DICOM_PATH --nifti_dir NIFTI_PATH

Evaluation

As our code for evaluating the model performance has slightly different dependencies, we provide a second .yml file to set up the evaluation environment. Simply use the following command to create and activate the new environment:

mamba env create -f eval/eval_environment.yml
mamba activate eval
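For reference, FID measures the Fréchet distance between Gaussians fitted to feature embeddings of real and generated images: FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). The NumPy sketch below computes this distance from two precomputed (N, F) feature matrices. Extracting the features (for example with the pretrained resnet_50_23dataset.pth model mentioned in the acknowledgements) is assumed to have happened already, and the function names are hypothetical rather than the interface of the evaluation scripts.

```python
import numpy as np


def _sqrtm_psd(a: np.ndarray) -> np.ndarray:
    """Principal square root of a symmetric positive semi-definite matrix."""
    w, v = np.linalg.eigh(a)
    w = np.clip(w, 0.0, None)  # drop tiny negative eigenvalues caused by round-off
    return (v * np.sqrt(w)) @ v.T


def frechet_distance(feats_real: np.ndarray, feats_fake: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (N, F) feature matrices."""
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)  # unbiased (N - 1) estimator
    cov_f = np.cov(feats_fake, rowvar=False)
    # Tr((cov_r cov_f)^(1/2)) equals Tr((S cov_f S)^(1/2)) with S = cov_r^(1/2);
    # S cov_f S is symmetric PSD, so a symmetric eigendecomposition is safe.
    s = _sqrtm_psd(cov_r)
    m = s @ cov_f @ s
    tr_covmean = np.trace(_sqrtm_psd((m + m.T) / 2.0))
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(cov_r) + np.trace(cov_f) - 2.0 * tr_covmean)
```

Two identical feature sets give a distance of (numerically) zero, and shifting every feature of one set by a constant c adds F * c^2 to the distance while leaving the covariance term unchanged.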

FID

For computing the FID score, you need to specify the following variables and use them in the command below:

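MS-SSIM

For computing the MS-SSIM score, specify the dataset DATASET, the image size IMG_SIZE and the directory containing the generated samples SAMPLE_DIR, and use them in the command below:

python eval/ms_ssim.py --dataset DATASET --img_size IMG_SIZE --sample_dir SAMPLE_DIR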
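eval/ms_ssim.py computes MS-SSIM, which serves as a sample-diversity score: the similarity between pairs of generated samples is averaged, so lower values indicate more diverse samples. The sketch below shows only this pairing and averaging step. The similarity function is a pluggable callable, and the stand-in global_ssim computes a single-window SSIM over the whole volume, a simplification that is not the multi-scale, windowed MS-SSIM computed by the evaluation script. All names are hypothetical, and which pairs the script actually compares is not specified here.

```python
import itertools
from typing import Callable, Sequence

import numpy as np


def global_ssim(a: np.ndarray, b: np.ndarray, data_range: float = 1.0) -> float:
    """Single-window SSIM over the whole volume (simplified stand-in, not MS-SSIM)."""
    c1 = (0.01 * data_range) ** 2  # standard SSIM stabilizing constants
    c2 = (0.03 * data_range) ** 2
    mu_a, mu_b = a.mean(), b.mean()
    var_a, var_b = a.var(), b.var()
    cov = ((a - mu_a) * (b - mu_b)).mean()
    num = (2 * mu_a * mu_b + c1) * (2 * cov + c2)
    den = (mu_a ** 2 + mu_b ** 2 + c1) * (var_a + var_b + c2)
    return float(num / den)


def mean_pairwise_similarity(
    samples: Sequence[np.ndarray],
    similarity: Callable[[np.ndarray, np.ndarray], float],
) -> float:
    """Average similarity over all unordered pairs of samples (lower = more diverse)."""
    pairs = list(itertools.combinations(range(len(samples)), 2))
    if not pairs:
        raise ValueError("need at least two samples")
    return float(np.mean([similarity(samples[i], samples[j]) for i, j in pairs]))
```

Swapping global_ssim for a full MS-SSIM implementation leaves mean_pairwise_similarity unchanged, which is why the similarity is passed in as a callable.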

Implementation Details for Comparing Methods

All experiments were performed on a system with an AMD EPYC 7742 CPU and an NVIDIA A100 (40 GB) GPU.

TODOs

We plan to add further functionality to our framework:

Acknowledgements

Our code is based on / inspired by the following repositories:

For computing FID scores we use a pretrained model (resnet_50_23dataset.pth) from:

Thanks for making these projects open-source.