JuliaWolleb / Diffusion-based-Segmentation

This is the official PyTorch implementation of the paper "Diffusion Models for Implicit Image Segmentation Ensembles".
MIT License
280 stars, 38 forks

How to modify the training and sampling process to generate an enhanced image instead of a mask? #53

Open FJGEODEV opened 1 year ago

FJGEODEV commented 1 year ago

Hi, thanks for the great work. I am wondering whether this repo can be used to generate enhanced images, e.g. super-resolution images.

I modified your code but still couldn't get the expected results, so I hope you can give me some hints. Here is what I did:

Training data: low-resolution (LR) images (treated as the MRI images) and high-resolution (HR) images (treated as the masks).

I trained following the repo, but commented out label=torch.where(label > 0, 1, 0).float().

My expectation was that training would learn the epsilon between the LR and HR images (just like the MRI and its mask).

Then I ran the inference part with a data folder containing only LR images, trying to generate HR images from the trained model.
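For readers following along, here is a minimal sketch of what the commented-out line does and what a continuous-valued replacement might look like. The helper name `scale_to_unit_range` and the choice of a [-1, 1] target range are assumptions made for illustration; they are not taken from the repo, and the range your data loader should produce depends on how the rest of the pipeline normalizes images.

```python
import torch

def binarize_label(label: torch.Tensor) -> torch.Tensor:
    # The line quoted above: every positive value becomes 1, everything else 0.
    # Applied to an HR image, this would destroy all intensity information.
    return torch.where(label > 0, 1, 0).float()

def scale_to_unit_range(img: torch.Tensor) -> torch.Tensor:
    # Hypothetical replacement (assumption): min-max scale a continuous image
    # to [-1, 1] so the target keeps its grey values.
    lo, hi = img.min(), img.max()
    return 2.0 * (img - lo) / (hi - lo).clamp_min(1e-8) - 1.0

hr = torch.tensor([[0.0, 0.2], [0.5, 1.0]])
binary = binarize_label(hr)        # only 0s and 1s survive
scaled = scale_to_unit_range(hr)   # intensities preserved, rescaled
print(binary)
print(scaled)
```

Whichever scaling is used, the same one would need to be applied consistently at training and at inference time.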

Is there anything I missed? I tried to find the code that sets the "binary mask", but had no luck.

I am not sure whether this workflow is suitable for this kind of task.

I really appreciate your help!

JuliaWolleb commented 1 year ago

Hi. Yes, the same repo can be used for super-resolution. Check out the Palette paper, where they did something very similar. The training needs to learn the epsilon between the true and the predicted HR images. The LR images are just the condition (i.e., the MRI in my approach). Add the LR images as a condition by concatenating them to the noisy HR image x_t. So the input to your model has 2 channels: the noisy HR image x_t and the non-noisy LR image. The output of the model has 1 channel: the slightly denoised HR image x_{t-1}.

During inference (to generate the HR image), you start from random noise and concatenate the LR image in every time step. This is just like how we concatenate the MR image in every time step to generate the corresponding mask. I would expect this LR-to-HR setup to work fine for this type of task. Let me know if you have any other questions.
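The recipe above could be sketched roughly as follows. This is not the repo's code: `TinyDenoiser` is a hypothetical stand-in for the U-Net (it ignores the timestep for brevity), the linear beta schedule and `T = 50` are arbitrary choices, and the sketch assumes standard DDPM training in which the network predicts the added noise and x_{t-1} is computed from that prediction. It also assumes the LR image has already been upsampled to the HR grid (`lr_up`), since concatenation along the channel axis requires matching spatial sizes.

```python
import torch
import torch.nn as nn

T = 50                                       # number of diffusion steps (assumption)
betas = torch.linspace(1e-4, 0.02, T)        # linear noise schedule (assumption)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)

class TinyDenoiser(nn.Module):
    """Hypothetical stand-in for the U-Net: 2 channels in, 1 channel out."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(2, 16, 3, padding=1), nn.SiLU(),
            nn.Conv2d(16, 1, 3, padding=1),
        )

    def forward(self, x, t):
        return self.net(x)                   # t is ignored in this toy model

def training_loss(model, hr, lr_up):
    # Noise only the HR target; the LR condition stays clean.
    b = hr.shape[0]
    t = torch.randint(0, T, (b,))
    noise = torch.randn_like(hr)
    a_bar = alphas_bar[t].view(b, 1, 1, 1)
    x_t = a_bar.sqrt() * hr + (1 - a_bar).sqrt() * noise
    model_in = torch.cat([x_t, lr_up], dim=1)          # 2 channels: [x_t, LR]
    return nn.functional.mse_loss(model(model_in, t), noise)

@torch.no_grad()
def sample(model, lr_up):
    # Start from pure noise and re-attach the LR image at every step.
    x = torch.randn_like(lr_up)
    for i in reversed(range(T)):
        t = torch.full((lr_up.shape[0],), i, dtype=torch.long)
        eps = model(torch.cat([x, lr_up], dim=1), t)
        mean = (x - betas[i] / (1 - alphas_bar[i]).sqrt() * eps) / alphas[i].sqrt()
        noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
        x = mean + betas[i].sqrt() * noise
    return x

model = TinyDenoiser()
hr = torch.rand(2, 1, 16, 16) * 2 - 1        # dummy HR batch in [-1, 1]
lr_up = torch.rand(2, 1, 16, 16) * 2 - 1     # dummy upsampled LR batch
loss = training_loss(model, hr, lr_up)
loss.backward()
out = sample(model, lr_up)
```

The point to carry over is that the condition is concatenated both in the training step and inside the sampling loop; if the two paths build the model input differently, the model sees inputs at inference time that it never saw in training.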

FJGEODEV commented 1 year ago

Thank you so much, will try.

yug125lk commented 1 year ago

Hi, did you try it? Could you help me, please? My custom dataset (grayscale) contains only two images (input and target). I changed dataloader.py and gaussian_diffusion.py, and also these two lines in script_util.py (the input and the output channels). I set learn_sigma to False.

return UNetModel(
    image_size=image_size,
    in_channels=2,
    model_channels=num_channels,
    out_channels=1,  # (3 if not learn_sigma else 6)

When I tested it, I got only noise in the sampled output image.
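As a side note on the channel settings quoted above, the original comment `(3 if not learn_sigma else 6)` suggests that the output width doubles when the variance is learned. A small sketch of that bookkeeping for a single-channel target is given below; `unet_out_channels` and `split_model_output` are hypothetical helper names, and the split along the channel axis is an assumption about how a learned-variance output would be unpacked, not a quote from the repo.

```python
import torch

def unet_out_channels(target_channels: int, learn_sigma: bool) -> int:
    # With learn_sigma the network emits the noise estimate plus one
    # variance value per target channel, so the width doubles.
    return target_channels * 2 if learn_sigma else target_channels

def split_model_output(model_output, target_channels, learn_sigma):
    # Hypothetical unpacking step (assumption): separate the noise
    # estimate from the variance values along the channel axis.
    if learn_sigma:
        eps, var_values = torch.split(model_output, target_channels, dim=1)
        return eps, var_values
    return model_output, None

# Grayscale target: 1 channel without learn_sigma, 2 with it.
print(unet_out_channels(1, learn_sigma=False))
print(unet_out_channels(1, learn_sigma=True))

fake_output = torch.zeros(4, unet_out_channels(1, True), 8, 8)
eps, var_values = split_model_output(fake_output, 1, learn_sigma=True)
```

If `out_channels` and the `learn_sigma` flag disagree anywhere in the pipeline, the loss and sampling code would be reading the wrong slices of the output, so it may be worth checking that they match.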
