explainingai-code / StableDiffusion-PyTorch

This repo implements a Stable Diffusion model in PyTorch with all the essential components.

Unexpected output after sampling using Conditional LDM #9

Closed: randomaccess2023 closed this issue 3 months ago

randomaccess2023 commented 3 months ago

I am getting a bunch of noodle-like waves instead of proper digits after sampling from the conditional LDM. The unconditional LDM works fine. I am using the MNIST dataset from Torchvision (torchvision.datasets.MNIST).

Can you tell what could be wrong in this scenario?

I have attached the x0_0.png outputs for both Unconditional_LDM and Conditional_LDM.
[Images: Unconditional_LDM_x0_0, Conditional_LDM_x0_0]

explainingai-code commented 3 months ago

Hi @randomaccess2023, can you share your config and the sampling script (if you have made any changes to it)? Also, what is the shape of xt you get here? This looks like the model is being asked to generate images of a different size than the one it was trained with, so I just want to confirm that this is not the case.

explainingai-code commented 3 months ago

Also, I think the model is getting images scaled to 0-1 rather than -1 to 1. If that's the case, then after loading the MNIST images make sure to do that scaling as well (similar to https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/dataset/mnist_dataset.py#L91).
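A minimal sketch of that scaling, assuming the images arrive in the [0, 1] range as produced by `torchvision.transforms.ToTensor()`. The helper `scale_to_minus_one_one` is hypothetical, not a function from the repo; the linked dataset code may implement it differently. With torchvision alone, adding `transforms.Normalize((0.5,), (0.5,))` after `ToTensor()` gives the same mapping, since (x - 0.5) / 0.5 = 2x - 1.

```python
import torch

def scale_to_minus_one_one(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: maps pixel values from [0, 1] to [-1, 1].
    return x * 2.0 - 1.0

# Stand-in for a batch of MNIST images after ToTensor(): shape (B, 1, 28, 28), values in [0, 1].
imgs = torch.rand(4, 1, 28, 28)
scaled = scale_to_minus_one_one(imgs)
print(scaled.min().item(), scaled.max().item())  # both fall within [-1, 1]
```

The same range has to be used for both the VQ-VAE and the LDM training data, because the sampling code decodes latents back into whatever range the autoencoder learned.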

randomaccess2023 commented 3 months ago


@explainingai-code Yes, you are spot on. I had set BCHW to (25, 3, 28, 28) instead of (25, 3, 7, 7). I've changed it.
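The fix comes down to drawing the initial noise at the latent resolution rather than the pixel resolution. A minimal sketch, assuming an autoencoder downsampling factor of 4 (inferred from 28 → 7 in the shapes above) and 3 latent channels; `latent_shape` is a hypothetical helper, and in practice the factor follows from how many downsampling stages the VQ-VAE config uses.

```python
import torch

def latent_shape(batch_size: int, latent_channels: int, image_size: int, downsample_factor: int):
    # Hypothetical helper: the LDM denoises in the autoencoder's latent space,
    # so the starting noise must match the latent spatial size, not the image size.
    if image_size % downsample_factor != 0:
        raise ValueError("image_size must be divisible by downsample_factor")
    side = image_size // downsample_factor
    return (batch_size, latent_channels, side, side)

# Assumed values matching the thread: 25 samples, 3 latent channels,
# 28x28 MNIST images, and an autoencoder that downsamples by 4.
shape = latent_shape(25, 3, 28, 4)
xt = torch.randn(shape)  # initial noise for the reverse diffusion loop
print(tuple(xt.shape))  # (25, 3, 7, 7)
```

Sampling at (25, 3, 28, 28) instead asks the UNet to denoise a latent 16 times larger in area than anything it saw during training, which is consistent with the texture-like "noodle" outputs described in the issue.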

randomaccess2023 commented 3 months ago


@explainingai-code Yes, I scaled the images to the range 0 to 1 rather than -1 to 1; I always scale to 0 to 1. I changed it, and that resulted in a higher codebook loss for the VQ-VAE after training for 10 epochs. Now I am training the LDM, which will take longer since it runs for 100 epochs. Also, I fed the labels directly to an embedding table and then added the result to the position encoding. That worked for my conditional DDPM, which is why I didn't use the one-hot method.
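A sketch of the label-conditioning approach described above, assuming a standard sinusoidal timestep encoding and 10 MNIST classes. `TimeAndClassEmbedding` and `sinusoidal_time_embedding` are hypothetical names for illustration; neither the commenter's code nor the repo's UNet is shown in this thread.

```python
import math
import torch
import torch.nn as nn

def sinusoidal_time_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Transformer-style sinusoidal encoding of integer timesteps; dim is assumed even.
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

class TimeAndClassEmbedding(nn.Module):
    # Hypothetical module: adds a learned class embedding to the timestep encoding.
    def __init__(self, num_classes: int, dim: int):
        super().__init__()
        self.dim = dim
        self.class_emb = nn.Embedding(num_classes, dim)

    def forward(self, t: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Integer labels index the embedding table directly (no one-hot tensor),
        # and the result is summed with the timestep encoding into one vector.
        return sinusoidal_time_embedding(t, self.dim) + self.class_emb(labels)

emb = TimeAndClassEmbedding(num_classes=10, dim=128)
t = torch.randint(0, 1000, (25,))
labels = torch.randint(0, 10, (25,))
print(emb(t, labels).shape)  # torch.Size([25, 128])
```

Both routes are equivalent in what they can express: a lookup into `nn.Embedding(10, d)` computes the same thing as a bias-free `nn.Linear(10, d)` applied to a one-hot vector, just without materializing the one-hot tensor. That is why the one-hot method mentioned above and a direct embedding lookup should both work.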

explainingai-code commented 3 months ago

Yeah, that way of conditioning is also fine. By the way, if you are going to use the 0-1 scaling, then do remove this line. It's needed only if the output images are in the -1 to 1 range.
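A hedged sketch of the kind of post-processing step being discussed, assuming the sampling script clamps the decoded images and shifts them into [0, 1] before saving. The function `to_image_range` and its flag are hypothetical; the actual line referenced in the repo's sampling script may look different.

```python
import torch

def to_image_range(x: torch.Tensor, trained_on_minus_one_one: bool = True) -> torch.Tensor:
    # Hypothetical helper: maps decoder outputs to [0, 1] before saving as images.
    if trained_on_minus_one_one:
        x = torch.clamp(x, -1.0, 1.0)
        x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]; only correct for [-1, 1] training data
    else:
        x = torch.clamp(x, 0.0, 1.0)  # already in [0, 1], so no shift is applied
    return x

decoded = torch.tensor([-1.5, -1.0, 0.0, 1.0])
print(to_image_range(decoded))
```

Applying the `(x + 1) / 2` shift to outputs from a model trained on [0, 1] data would squeeze them into [0.5, 1], giving washed-out images, which is why that step only belongs in the [-1, 1] pipeline.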

randomaccess2023 commented 3 months ago

@explainingai-code Thanks a lot for your help. I got good outputs after scaling the images to the -1 to 1 range.

[Images: Unconditional_LDM x0_0 output, Conditional_LDM x0_0 output]