explainingai-code / StableDiffusion-PyTorch

This repo implements a Stable Diffusion model in PyTorch with all the essential components.
MIT License

Missing Scaling factor #33

Open sunly92 opened 4 days ago

sunly92 commented 4 days ago

Hi,

This is a really good repo for learning Stable Diffusion from scratch. However, I noticed that the scaling factor that should be applied to the latent $z$ before the U-Net is missing. It is said to keep the variance of the latent close to 1 (unit variance), which could facilitate training. A detailed discussion can be found at: https://github.com/huggingface/diffusers/issues/437
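To make the request concrete, here is a minimal sketch of what such a scaling step might look like. It is not code from this repo: the helper names `scale_latent` and `unscale_latent` are hypothetical, the random tensor is a stand-in for real autoencoder latents, and the factor is assumed to be 1 / std of the latents.

```python
import torch


def scale_latent(z: torch.Tensor, scale_factor: float) -> torch.Tensor:
    """Rescale encoder latents before they are noised and passed to the U-Net."""
    return z * scale_factor


def unscale_latent(z: torch.Tensor, scale_factor: float) -> torch.Tensor:
    """Undo the rescaling before latents are passed to the decoder."""
    return z / scale_factor


torch.manual_seed(0)

# Stand-in for autoencoder latents whose std is far from 1 (assumption for illustration).
z = torch.randn(8, 4, 32, 32) * 5.0

# Assumed definition: the factor is the reciprocal of the latent std.
scale_factor = 1.0 / z.std().item()

z_scaled = scale_latent(z, scale_factor)               # what the diffusion model would see
z_restored = unscale_latent(z_scaled, scale_factor)    # what the decoder would see

print(f"std before: {z.std().item():.3f}, after: {z_scaled.std().item():.3f}")
```

If the latent std is already close to 1, the factor is close to 1 and both helpers are effectively no-ops.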

Cheers, Liyan

explainingai-code commented 3 days ago

Hello @sunly92 ,

Thank you :) You are right that the scaling factor is not present, but the authors only use this scaling for the VAE and not the VQVAE. You can see the scale_by_std parameter defined in this config but not here. From the paper: "Note that the VQ-regularized space has a variance close to 1, such that it does not have to be rescaled." Since I only provided the code for VQVAE training in this repo, this scaling was not required.

Interestingly, when I computed the latent std for both datasets, I ended up with a std close to 1 (not just for the VQVAE but also for the VAE), so latent scaling would not have made any difference even for the VAE. That is why, even in other repos where I have also trained a VAE, I decided to skip it rather than add the logic to compute the std for every batch item and do the scaling.
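For anyone who wants to run the same check on their own autoencoder, the following is a rough sketch of how the latent std could be estimated once over a set of batches, instead of per batch item. It is an illustration only, not the code used in this repo: `estimate_latent_std` and `toy_encode` are hypothetical names, and the toy encoder simply multiplies its input by 3 so that the expected std is known.

```python
import torch


@torch.no_grad()
def estimate_latent_std(encode_fn, batches) -> float:
    """Estimate the population std of all latent values across batches.

    encode_fn: hypothetical callable mapping an input batch to a latent tensor.
    batches:   iterable of input tensors.
    """
    n = 0
    total = 0.0
    total_sq = 0.0
    for x in batches:
        z = encode_fn(x).double()  # accumulate in float64 for numerical stability
        n += z.numel()
        total += z.sum().item()
        total_sq += (z ** 2).sum().item()
    mean = total / n
    var = total_sq / n - mean ** 2  # E[z^2] - E[z]^2
    return var ** 0.5


torch.manual_seed(0)

# Toy stand-in for an encoder (assumption): latents have std of roughly 3.
toy_encode = lambda x: 3.0 * x
batches = [torch.randn(16, 4, 8, 8) for _ in range(10)]

latent_std = estimate_latent_std(toy_encode, batches)
scale_factor = 1.0 / latent_std
print(f"latent std: {latent_std:.3f}, scale factor: {scale_factor:.3f}")
```

A std near 1 gives a scale factor near 1, in which case skipping the scaling changes very little.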

Hope this helps.