JuliaWolleb / Diffusion-based-Segmentation

This is the official Pytorch implementation of the paper "Diffusion Models for Implicit Image Segmentation Ensembles".
MIT License

Using DDIM to decrease sampling time #32

Closed JesseWiers closed 1 year ago

JesseWiers commented 1 year ago

Hello again. I was wondering whether it is normal behaviour to get an error when trying to sample with --use_ddim True. Do you have to change anything in the training process in order to be able to use DDIM for sampling afterwards? The specific error I am getting is a size mismatch:

RuntimeError: Given groups=1, weight of size [256, 2, 3, 3], expected input[1, 3, 256, 256] to have 2 channels, but got 3 channels instead

Does the Unet architecture need to be changed in order to be able to use DDIM?
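For context, the error itself is easy to reproduce outside the repo: the first convolution of the UNet has weights of shape [256, 2, 3, 3], i.e. it expects 2 input channels, so feeding a 3-channel tensor fails. A minimal stand-alone sketch (not the repo's actual UNet, just an illustration of the mismatch):

```python
import torch
import torch.nn as nn

# Stand-in for the UNet's first conv: weight shape [256, 2, 3, 3],
# i.e. 2 expected input channels (image + noise).
conv = nn.Conv2d(in_channels=2, out_channels=256, kernel_size=3, padding=1)

ok = torch.randn(1, 2, 256, 256)   # 1 image channel + 1 noise channel
out = conv(ok)                     # works, output shape (1, 256, 256, 256)

bad = torch.randn(1, 3, 256, 256)  # one channel too many
try:
    conv(bad)
except RuntimeError as e:
    # Same class of error as reported above: channel count mismatch.
    print("mismatch:", e)
```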

JuliaWolleb commented 1 year ago

Hi, no, you should be able to use --use_ddim True for sampling. There seems to be an error in your input image: it should be of dimension [1, 3, 256, 256]. Can you double-check your input dimensions?

JesseWiers commented 1 year ago

Thank you for your quick reply. I am using the code on another dataset composed of organoid culture images. These images have a single channel, so for me the input dimension should be [1, 2, 256, 256]. The problem occurred because of the following line in gaussian_diffusion.py, in ddim_sample_loop_progressive():

if img.shape != (1, 5, 224, 224): img = torch.cat((orghigh,img), dim=1).float()

which added a second noise layer to img. Changing it to if img.shape != (1, 2, 224, 224): resolved the issue. However, ddim_sample_loop_known() outputs 2 channels instead of 1. This is because eps in ddim_sample takes the shape of x_t, which also contains the original image channels (MRI in the original setup, not in my case). The same holds for alpha_bar and alpha_bar_prev in ddim_sample.

I changed the code so that the calculations of eps, alpha_bar, and alpha_bar_prev only use the last channel (the current sample). The final samples all have pixel values close to 0, resulting in no meaningful segmentation maps. The segmentation maps without DDIM are, however, perfectly fine. Do you have any clue why this is the case? Should I change the calculations of eps, alpha_bar, and alpha_bar_prev in another way?
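The split described above can be sketched as follows. This is a hypothetical helper (the names `split_ddim_input`, `cond`, and `sample` are not from the repo), illustrating the idea that the DDIM update should act only on the channel being denoised while the conditioning image channels stay fixed:

```python
import torch

def split_ddim_input(x_t: torch.Tensor):
    """Hypothetical helper: x_t stacks the conditioning image channel(s)
    with the current noisy segmentation in the last channel. Only the last
    channel should enter the DDIM update (eps and the alpha_bar scaling)."""
    cond = x_t[:, :-1]    # original image channel(s), kept fixed
    sample = x_t[:, -1:]  # the channel actually being denoised
    return cond, sample

# 1 image channel + 1 noise channel, as in the 1-channel dataset above:
x_t = torch.randn(1, 2, 224, 224)
cond, sample = split_ddim_input(x_t)
```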

xupinggl commented 1 year ago

I have encountered a new problem with DDIM on Windows. I used the BraTS dataset provided by the author and didn't change anything. Can you help me solve it?

[screenshot of the error]

JuliaWolleb commented 1 year ago

@JesseWiers Sorry for the late reply.

First of all, yes: if your input image has one channel, you will need to change the line if img.shape != (1, 5, 224, 224): img = torch.cat((orghigh,img), dim=1).float() to if img.shape != (1, 2, 224, 224): img = torch.cat((orghigh,img), dim=1).float().
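One way to avoid hardcoding the full shape (a hypothetical variant, not the repo's code) is to key the check on the channel count alone, where `expected_channels` is the number of image channels plus one noise channel:

```python
import torch

expected_channels = 2                  # 1 image channel + 1 noise channel
orghigh = torch.randn(1, 1, 224, 224)  # conditioning image (name from the repo)
img = torch.randn(1, 1, 224, 224)      # freshly sampled noise

# Check the channel dimension rather than the full hardcoded shape,
# so the same code survives different image sizes:
if img.shape[1] != expected_channels:
    img = torch.cat((orghigh, img), dim=1).float()
```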

As for sampling, you would need to stick to p_sample_loop_known() if you also want to generate uncertainty maps; there the output dimension should be 1. If you want to switch to the DDIM sampling scheme, you need to double-check that the output dimensions of p_sample_loop_known() and ddim_sample_loop_known() are the same.
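That consistency check can be made explicit with a small sanity-check helper (hypothetical, not part of the repo): whichever sampler is used, the final sample tensors should agree in shape before an uncertainty map is computed over the ensemble.

```python
import torch

def assert_same_output_shape(sample_a: torch.Tensor, sample_b: torch.Tensor):
    """Fail loudly if two samplers disagree on output shape."""
    assert sample_a.shape == sample_b.shape, (
        f"sampler mismatch: {tuple(sample_a.shape)} vs {tuple(sample_b.shape)}"
    )

# e.g. outputs of p_sample_loop_known() and ddim_sample_loop_known():
a = torch.zeros(1, 1, 256, 256)
b = torch.zeros(1, 1, 256, 256)
assert_same_output_shape(a, b)  # passes silently when shapes agree
```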

xupinggl commented 1 year ago

[Auto-reply] Hello, I have received your email and will reply as soon as possible. Thank you!

JuliaWolleb commented 1 year ago

@xupinggl This is a CUDA error on your machine. Can you try to run the code on another GPU?
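A quick way to try another GPU without touching the code is to restrict which device CUDA sees. This is a generic diagnostic (not a repo feature); CUDA_VISIBLE_DEVICES must be set before the first CUDA call to take effect:

```python
import os
import torch

# Pin the process to a different GPU (index 1 here) before any CUDA work.
# setdefault keeps an externally provided value if one is already set.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

print("CUDA available:", torch.cuda.is_available())
print("visible devices:", torch.cuda.device_count())
```

Equivalently, `CUDA_VISIBLE_DEVICES=1 python segmentation_sample.py ...` on the command line (script name as used in this repo).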

JesseWiers commented 1 year ago

Great, thank you. Moreover, is sampling supposed to work with a batch size higher than 1 (i.e. performing multiple diffusion passes at once)?

JuliaWolleb commented 1 year ago

I never tried it, but nothing should prevent it, as long as it fits on your GPU.
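Batching indeed only changes the leading dimension of the tensors; a model that runs with batch size 1 runs with larger batches, memory permitting. A minimal illustration with a stand-in layer (not the repo's UNet):

```python
import torch
import torch.nn as nn

# Stand-in for the UNet: any Conv2d handles an arbitrary batch dimension.
conv = nn.Conv2d(2, 8, 3, padding=1)

batch = torch.randn(4, 2, 256, 256)  # 4 diffusion passes at once
out = conv(batch)
print(out.shape)                     # torch.Size([4, 8, 256, 256])
```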
