explainingai-code / StableDiffusion-PyTorch

This repo implements a Stable Diffusion model in PyTorch with all the essential components.

Why are the MNIST images generated by your model just noise? #4

Closed: CatLoves closed this issue 6 months ago

CatLoves commented 6 months ago

Dear explainingai-code: Your DDPM codebase is so detailed and helpful that I would like to thank you very much for your great work! I downloaded the codebase and followed your instructions carefully, and luckily I got a good VQVAE result, as follows: [image]

So I went on to train the unconditional and class-conditional DDPM models.
I ran python tools/train_ddpm_vqvae.py to train the unconditional DDPM model, and it seems to have converged, as follows:

[image] The loss decreased from 0.2833 to 0.0886. Then I ran python tools/sample_ddpm_vqvae.py to check the model output and got mnist/samples/x0_0.png through x0_999.png, all of which look like PURE NOISE, as follows: [image]

Similarly, I trained and tested the MNIST class-conditional DDPM, and all of its results look like PURE NOISE as well:

[image]

My code structure is as follows:

[image] I am sure that I only followed your instructions and DID NOT modify any model-related code, but the results seem weird. I would appreciate any help you could give.

Sincerely, CatLoves

explainingai-code commented 6 months ago

Hello @CatLoves, thank you for the appreciation. Can you please share both the x0_0.png and x0_999.png files? x0_0.png is a decoded image (this is the actual generated image), while x0_999.png will be pure noise for the unconditional model. For the conditional model, however, even x0_999.png (which is the very first latent image prediction) will still look something like digits (like the ones you uploaded).
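The file numbering described above can be sketched as a sampling loop that runs from the last timestep down to t = 0 and passes only the final result through the decoder. This is a simplified, stdlib-only sketch under stated assumptions: denoise_step and vae_decode are hypothetical stand-ins for the trained DDPM and the VQVAE decoder, and the repo's actual tools/sample_ddpm_vqvae.py may differ in detail.

```python
import random


# Hypothetical stand-ins (assumptions, not the repo's API).
def denoise_step(latent, t):
    """Toy reverse-diffusion step: scales every value by 0.9."""
    return [0.9 * v for v in latent]


def vae_decode(latent):
    """Toy decoder: tags the latent so we can tell it was decoded."""
    return ("decoded_image", list(latent))


def sample(num_timesteps=1000, latent_size=4, seed=0):
    """Return a dict mapping x0_<t>.png names to what would be saved there."""
    rng = random.Random(seed)
    # Start from pure Gaussian noise in latent space.
    xt = [rng.gauss(0.0, 1.0) for _ in range(latent_size)]
    saved = {}
    for i in reversed(range(num_timesteps)):
        xt = denoise_step(xt, i)
        if i == 0:
            # Only the final step is decoded into pixel space.
            saved[f"x0_{i}.png"] = vae_decode(xt)
        else:
            # Every other step is saved as a raw latent, which looks like noise.
            saved[f"x0_{i}.png"] = ("latent", list(xt))
    return saved
```

With num_timesteps=1000 this produces x0_999.png through x0_0.png, and only x0_0.png holds a decoded image; the rest are raw latents, which is why they look like noise for the unconditional model.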

CatLoves commented 6 months ago

@explainingai-code Sure. The x0_0.png under mnist/samples is as follows: [image] It looks very nice. So the cause of this problem was that I hadn't realized only the last step's output is decoded; x0_1.png to x0_999.png are saved latents rather than decoded images. Sorry for the misunderstanding. For the class-conditional DDPM on MNIST, x0_0.png is as follows: [image] It also makes sense. Thank you again for your quick and detailed response. I look forward to your future work. Have a nice day!
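Given that only x0_0.png is decoded, a small helper can separate the final generated image from the intermediate latent snapshots when browsing a samples folder. This is a hypothetical utility, not part of the repo; it only assumes files are named x0_<t>.png, as in this thread. Note that the timestep must be sorted numerically, since a plain string sort would place x0_10.png before x0_2.png.

```python
import re
from pathlib import Path

# Matches names like x0_0.png, x0_999.png and captures the timestep.
_PATTERN = re.compile(r"x0_(\d+)\.png$")


def split_samples(filenames):
    """Return (final_image, latent_snapshots) from x0_<t>.png file names.

    final_image is the x0_0.png entry (the decoded sample), or None if absent.
    latent_snapshots is ordered from the noisiest step (largest t) down to t=1.
    Names that do not match the pattern are ignored.
    """
    indexed = []
    for name in filenames:
        match = _PATTERN.search(Path(name).name)
        if match:
            indexed.append((int(match.group(1)), name))
    # Numeric sort on the timestep, largest first.
    indexed.sort(key=lambda pair: pair[0], reverse=True)
    final = next((name for t, name in indexed if t == 0), None)
    latents = [name for t, name in indexed if t != 0]
    return final, latents
```

For example, split_samples(p.name for p in Path("mnist/samples").glob("x0_*.png")) would return x0_0.png as the image to inspect.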

explainingai-code commented 6 months ago

That's great! I understand how the image numbering could be confusing. I have now added this to the Output section of the README: https://github.com/explainingai-code/StableDiffusion-PyTorch/tree/main?tab=readme-ov-file#output

Thank you