This is the official PyTorch implementation of the paper WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis by Paul Friedrich, Julia Wolleb, Florentin Bieder, Alicia Durrer and Philippe C. Cattin.
If you find our work useful, please consider :star: starring this repository and :memo: citing our paper:
@article{friedrich2024wdm,
    title={WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis},
    author={Paul Friedrich and Julia Wolleb and Florentin Bieder and Alicia Durrer and Philippe C. Cattin},
    year={2024},
    journal={arXiv preprint arXiv:2402.19043}
}
Due to the three-dimensional nature of CT or MR scans, generative modeling of medical images is a particularly challenging task. Existing approaches mostly apply patch-wise, slice-wise, or cascaded generation techniques to fit the high-dimensional data into the limited GPU memory. However, these approaches may introduce artifacts and potentially restrict the model's applicability to certain downstream tasks. This work presents WDM, a wavelet-based medical image synthesis framework that applies a diffusion model on wavelet-decomposed images. The presented approach is a simple yet effective way of scaling diffusion models to high resolutions and can be trained on a single 40 GB GPU. Experimental results on unconditional image generation on BraTS and LIDC-IDRI at a resolution of 128 x 128 x 128 show state-of-the-art image fidelity (FID) and sample diversity (MS-SSIM) scores compared to GANs, Diffusion Models, and Latent Diffusion Models. Our proposed method is the only one capable of generating high-quality images at a resolution of 256 x 256 x 256.
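The key mechanism is to run the diffusion model on the wavelet coefficients of a volume instead of the volume itself: a single-level 3D discrete wavelet transform turns a 128 x 128 x 128 scan into eight 64 x 64 x 64 sub-bands, halving the spatial resolution the network has to process along each axis. Below is a minimal sketch of that decomposition using PyWavelets; it illustrates the idea and is not the repository's own PyTorch-based wavelet transform.

```python
import numpy as np
import pywt  # PyWavelets

# Dummy 3D volume standing in for a 128^3 MR/CT scan.
volume = np.random.rand(128, 128, 128).astype(np.float32)

# Single-level 3D Haar DWT: returns 8 sub-bands of size 64^3
# ('aaa' is the low-pass approximation, the others are detail bands).
coeffs = pywt.dwtn(volume, wavelet='haar')
print(sorted(coeffs.keys()))      # ['aaa', 'aad', ..., 'ddd']
print(coeffs['aaa'].shape)        # (64, 64, 64)

# Stacking the 8 sub-bands as channels gives the lower-resolution,
# multi-channel input a diffusion model would operate on.
stacked = np.stack([coeffs[k] for k in sorted(coeffs)], axis=0)
print(stacked.shape)              # (8, 64, 64, 64)

# The inverse transform reconstructs the original volume.
recon = pywt.idwtn(coeffs, wavelet='haar')
print(np.allclose(recon, volume, atol=1e-5))  # True
```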
We recommend using a conda environment to install the required dependencies.
You can create and activate such an environment called wdm by running the following commands:
mamba env create -f environment.yml
mamba activate wdm
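After activating the environment, a quick sanity check (assuming the environment provides PyTorch with CUDA support, which training requires) confirms that the GPU is visible before launching a run:

```python
# Minimal sanity check, run inside the activated 'wdm' environment.
import torch

print(torch.__version__)                  # installed PyTorch version
print(torch.cuda.is_available())          # should be True on a GPU machine
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the visible GPU
```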
For training a new model or sampling from an already trained one, you can simply adapt and use the script run.sh. All relevant hyperparameters for reproducing our results are automatically set when using the correct MODEL in the general settings.
For executing the script, simply use the following command:
bash run.sh
Supported settings (set in the run.sh file):
MODE: 'training', 'sampling'
MODEL: 'ours_unet_128', 'ours_unet_256', 'ours_wnet_128', 'ours_wnet_256'
DATASET: 'brats', 'lidc-idri'
We released pretrained models on HuggingFace.
Currently available models:
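Once you have picked a checkpoint from the model card, it can also be fetched programmatically with the huggingface_hub client. The repository id and filename below are placeholders; replace them with the values listed on HuggingFace:

```python
from huggingface_hub import hf_hub_download

# Placeholder repo id and filename -- replace with the values from the
# HuggingFace model card of the checkpoint you want to use.
ckpt_path = hf_hub_download(
    repo_id="<user>/<wdm-checkpoints>",
    filename="<model>.pt",
)
print(ckpt_path)  # local path of the downloaded checkpoint
```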
To ensure good reproducibility, we trained and evaluated our network on two publicly available datasets:
BRATS 2023: Adult Glioma, a dataset containing routine clinically-acquired, multi-site multiparametric magnetic resonance imaging (MRI) scans of brain tumor patients. We only used the T1-weighted images for training. The data is available here.
LIDC-IDRI, a dataset containing multi-site, thoracic computed tomography (CT) scans of lung cancer patients. The data is available here.
The provided code works for the following data structure (you might need to adapt the DATA_DIR variable in run.sh):
data
└───BRATS
└───BraTS-GLI-00000-000
└───BraTS-GLI-00000-000-seg.nii.gz
└───BraTS-GLI-00000-000-t1c.nii.gz
└───BraTS-GLI-00000-000-t1n.nii.gz
└───BraTS-GLI-00000-000-t2f.nii.gz
└───BraTS-GLI-00000-000-t2w.nii.gz
└───BraTS-GLI-00001-000
└───BraTS-GLI-00002-000
...
└───LIDC-IDRI
└───LIDC-IDRI-0001
└───preprocessed.nii.gz
└───LIDC-IDRI-0002
└───LIDC-IDRI-0003
...
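Before training, the layout can be sanity-checked with a few lines of Python. This sketch only assumes the directory structure shown above; DATA_DIR corresponds to the dataset folder configured in run.sh:

```python
import os
from glob import glob

DATA_DIR = "data/BRATS"  # adapt to the DATA_DIR set in run.sh

# Every case folder should contain a T1-weighted image (the modality used for training).
cases = [c for c in sorted(os.listdir(DATA_DIR))
         if os.path.isdir(os.path.join(DATA_DIR, c))]
missing = [c for c in cases
           if not glob(os.path.join(DATA_DIR, c, "*-t1*.nii.gz"))]

print(f"{len(cases)} cases found, {len(missing)} without a T1-weighted image")
```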
We provide a script for preprocessing LIDC-IDRI. Simply run the following command with the correct path to the downloaded DICOM files (DICOM_PATH) and the directory where you want to store the processed NIfTI files (NIFTI_PATH):
python utils/preproc_lidc-idri.py --dicom_dir DICOM_PATH --nifti_dir NIFTI_PATH
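After preprocessing, every case folder should contain a preprocessed.nii.gz as shown in the structure above. A quick check with nibabel (assuming it is available in your environment) shows the shape, voxel spacing, and intensity range of a processed scan:

```python
import nibabel as nib

# Inspect one preprocessed LIDC-IDRI scan (path follows the layout above).
img = nib.load("data/LIDC-IDRI/LIDC-IDRI-0001/preprocessed.nii.gz")

print(img.shape)                 # volume dimensions
print(img.header.get_zooms())    # voxel spacing in mm
data = img.get_fdata()
print(data.min(), data.max())    # intensity range after preprocessing
```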
As our code for evaluating the model performance has slightly different dependencies, we provide a second .yml file to set up the evaluation environment. Simply use the following command to create and activate the new environment:
mamba env create -f eval/eval_environment.yml
mamba activate eval
For computing the FID score, you need to specify the following variables and use them in the command below:
DATASET: brats or lidc-idri
IMG_SIZE: 128 or 256
REAL_DATA_DIR: directory containing the real images
FAKE_DATA_DIR: directory containing the generated (fake) images
PATH_TO_FEATURE_EXTRACTOR: ./eval/pretrained/resnet_50_23dataset.pt
PATH_TO_ACTIVATIONS: ./eval/activations/
GPU_ID: 0
python eval/fid.py --dataset DATASET --img_size IMG_SIZE --data_root_real REAL_DATA_DIR --data_root_fake FAKE_DATA_DIR --pretrain_path PATH_TO_FEATURE_EXTRACTOR --path_to_activations PATH_TO_ACTIVATIONS --gpu_id GPU_ID
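For reference, the FID score itself is the Fréchet distance between Gaussian fits of the real and generated feature activations. The sketch below shows only this final step on already-extracted activations; in eval/fid.py the activations come from the pretrained ResNet-50 feature extractor:

```python
import numpy as np
from scipy import linalg

def frechet_distance(act_real, act_fake):
    """FID between two activation matrices of shape (n_samples, n_features)."""
    mu_r, mu_f = act_real.mean(axis=0), act_fake.mean(axis=0)
    cov_r = np.cov(act_real, rowvar=False)
    cov_f = np.cov(act_fake, rowvar=False)

    # Matrix square root of the product of the covariance matrices.
    covmean, _ = linalg.sqrtm(cov_r @ cov_f, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts caused by numerics

    diff = mu_r - mu_f
    return diff @ diff + np.trace(cov_r + cov_f - 2.0 * covmean)

# Toy example with random activations standing in for extracted features.
fid = frechet_distance(np.random.randn(200, 512), np.random.randn(200, 512))
print(fid)
```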
For computing the mean MS-SSIM, you need to specify the following variables and use them in the command below:
DATASET: brats or lidc-idri
IMG_SIZE: 128 or 256
SAMPLE_DIR: directory containing the generated samples
python eval/ms_ssim.py --dataset DATASET --img_size IMG_SIZE --sample_dir SAMPLE_DIR
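Mean MS-SSIM is used as a diversity measure: structural similarity is averaged over pairs of generated samples, and lower values indicate higher sample diversity. The following simplified sketch illustrates the idea, using the single-scale SSIM from scikit-image as a stand-in for the multi-scale version computed by eval/ms_ssim.py:

```python
from itertools import combinations

import numpy as np
from skimage.metrics import structural_similarity

# Stand-in samples; in practice, load the generated volumes from SAMPLE_DIR.
samples = [np.random.rand(64, 64, 64).astype(np.float32) for _ in range(4)]

# Average pairwise similarity: lower values mean more diverse samples.
scores = [structural_similarity(a, b, data_range=1.0)
          for a, b in combinations(samples, 2)]
print(np.mean(scores))
```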
All experiments were performed on a system with an AMD Epyc 7742 CPU and an NVIDIA A100 (40GB) GPU.
We plan to add further functionality to our framework:
Our code is based on / inspired by the following repositories:
For computing FID scores, we use a pretrained model (resnet_50_23dataset.pth) from:
Thanks for making these projects open-source.