mazurowski-lab / segmentation-guided-diffusion

[MICCAI 2024] Easy diffusion models (optionally with segmentation guidance) for medical images and beyond.
https://arxiv.org/abs/2402.05210

Samples of segmentation-guided DDPM/DDIM NOT consistent with their respective segmentation masks during training/inference (cardiac MRI) #17

Open varshak97 opened 2 days ago

varshak97 commented 2 days ago

Hello! I want to begin by expressing my appreciation for this work, particularly the straightforward mask-ablated training strategy. Thank you to you and your team for sharing the code!
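The mask-ablated training mentioned above can be illustrated with a small sketch. It assumes, purely for illustration and not as the repository's exact scheme, that each foreground class in a mask is independently reset to background with probability `p_drop`; `ablate_mask` and `p_drop` are hypothetical names, and masks are plain nested lists of integer class labels.

```python
import random
from typing import List, Optional

Mask = List[List[int]]  # 2-D grid of integer class labels, 0 = background


def ablate_mask(mask: Mask, num_classes: int, p_drop: float = 0.5,
                rng: Optional[random.Random] = None) -> Mask:
    """Return a copy of `mask` with a random subset of foreground classes set to 0.

    Hypothetical illustration: each class 1..num_classes-1 is dropped
    independently with probability `p_drop`. The input mask is not modified.
    """
    rng = rng or random.Random()
    # Decide, class by class, which foreground labels to remove.
    dropped = {c for c in range(1, num_classes) if rng.random() < p_drop}
    # Build a new grid where dropped classes become background.
    return [[0 if v in dropped else v for v in row] for row in mask]


mask = [[0, 1, 1], [2, 2, 3], [0, 3, 3]]
print(ablate_mask(mask, num_classes=4, rng=random.Random(0)))
```

The intent of this kind of training is to let the model accept masks in which only some of the classes are present at sampling time; `p_drop=0.0` returns the mask unchanged.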

I am facing an issue when training the segmentation-guided DDPM/DDIM model on cardiac MRI data (the public ACDC dataset): the generated samples are not consistent with their corresponding segmentation masks during both training and inference.

Details:

- Inconsistent results: the reconstructed/generated samples do not align with the input segmentation masks, i.e., the spatial features outlined by the masks are not reflected in the generated samples.
- Noisy samples: a noticeable fraction of the samples are noisy. Training for more epochs has partly reduced this, but after 1000 epochs there are still occasional noisy outputs. My primary concern, however, is the inconsistency between the generated samples and the segmentation masks.
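To make "not consistent" measurable rather than purely visual, one option is to segment each generated sample with a separately trained cardiac segmentation model (an assumption: any ACDC segmenter would do) and compare the result with the conditioning mask using per-class Dice. A minimal pure-Python sketch, where `per_class_dice` is a hypothetical helper operating on flattened label maps:

```python
from typing import Dict, Sequence


def per_class_dice(target: Sequence[int], pred: Sequence[int],
                   num_classes: int) -> Dict[int, float]:
    """Dice score per foreground class between two flattened label maps.

    target: conditioning mask; pred: segmentation of the generated sample.
    Both are flat sequences of integer class labels (0 = background).
    """
    if len(target) != len(pred):
        raise ValueError("label maps must have the same number of pixels")
    scores: Dict[int, float] = {}
    for c in range(1, num_classes):  # class 0 (background) is skipped
        t = [v == c for v in target]
        p = [v == c for v in pred]
        overlap = sum(a and b for a, b in zip(t, p))
        total = sum(t) + sum(p)
        # If the class is absent from both maps, treat it as a perfect match.
        scores[c] = 1.0 if total == 0 else 2.0 * overlap / total
    return scores


# Toy example with 4 classes (background + 3 structures).
cond_mask = [0, 1, 1, 2, 2, 3]
sample_seg = [0, 1, 0, 2, 2, 0]
print(per_class_dice(cond_mask, sample_seg, num_classes=4))
# {1: 0.6666666666666666, 2: 1.0, 3: 0.0}
```

Consistently low Dice for particular classes, averaged over many samples, would point to which parts of the conditioning mask the model is ignoring.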

I have double-checked my data preparation step and don't believe the issue lies there. This is the command I used to train the segmentation-guided diffusion model:

python3 segmentation-guided-diffusion/main.py \
    --mode "train" \
    --img_size 256 \
    --num_img_channels 1 \
    --dataset "cardiac_mri" \
    --img_dir "/home/varshak/scratch/acdc/segdiff/data_dir/" \
    --seg_dir "/home/varshak/scratch/acdc/segdiff/mask_dir/" \
    --model_type "DDPM" \
    --segmentation_guided \
    --num_segmentation_classes 4 \
    --train_batch_size 16 \
    --eval_batch_size 8 \
    --use_ablated_segmentations \
    --eval_noshuffle_dataloader \
    --num_epochs 1000
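Since a mismatch like this can still come from how images and masks are paired or encoded, a quick sanity check may help rule the data out. The sketch below assumes, hypothetically, that each mask shares its image's filename stem and stores integer class indices 0 to 3 (matching `--num_segmentation_classes 4`); the helper names are illustrative and not part of the repository. Only `mask_label_values` needs Pillow, and it uses the real `Image.getcolors` method.

```python
from pathlib import Path
from typing import Dict, Iterable, List, Set, Tuple


def pair_by_stem(image_files: Iterable[str], mask_files: Iterable[str]
                 ) -> Tuple[List[Tuple[str, str]], List[str], List[str]]:
    """Match image and mask files by filename stem (name without extension).

    Returns (matched pairs, images with no mask, masks with no image).
    """
    images: Dict[str, str] = {Path(f).stem: f for f in image_files}
    masks: Dict[str, str] = {Path(f).stem: f for f in mask_files}
    pairs = [(images[s], masks[s]) for s in sorted(images.keys() & masks.keys())]
    missing_masks = sorted(images[s] for s in images.keys() - masks.keys())
    orphan_masks = sorted(masks[s] for s in masks.keys() - images.keys())
    return pairs, missing_masks, orphan_masks


def unexpected_labels(values: Iterable[int], num_classes: int) -> Set[int]:
    """Return mask values outside the expected class indices 0..num_classes-1."""
    return {v for v in values if not 0 <= v < num_classes}


def mask_label_values(path: str) -> Set[int]:
    """Distinct pixel values in a single-channel mask image (requires Pillow)."""
    from PIL import Image  # imported lazily so the helpers above need only the stdlib
    with Image.open(path) as img:
        # maxcolors = pixel count guarantees getcolors() never returns None.
        counts = img.getcolors(maxcolors=img.width * img.height)
    return {pixel for _, pixel in counts}


pairs, no_mask, no_image = pair_by_stem(["img/a.png", "img/b.png"],
                                        ["seg/a.png", "seg/c.png"])
print(pairs, no_mask, no_image)
# A mask saved as scaled grayscale (0/85/170/255) instead of class indices:
print(unexpected_labels({0, 85, 170, 255}, num_classes=4))
```

Non-empty `no_mask`/`no_image` lists, or any unexpected label values, would suggest the conditioning masks seen by the model differ from the intended ones.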

Below are sample outputs from epoch 999 for context. The issue persists throughout the entire training process, not just at the beginning or the end. I would appreciate any insights or suggestions on how to resolve it! Thank you so much for your time :)

Samples from ddpm_cardiac-mri_256_segguided_abalated at epoch 999:

[Image: 0999_cond_seg_all.png, conditioning segmentation masks]

[Image: 0999_orig.png, original images]

[Image: 0999.png, generated samples]

varshak97 commented 1 day ago

Hi, here is a link to a Google Doc with additional training and inference samples, for further clarification and review of this issue.

Link: https://docs.google.com/document/d/17VbNJ9LZJPoskOSP2kckq5vs9L7E2h6X-Yb-iZVUzFw/edit?usp=sharing