forever208 / DDPM-IP

[ICML 2023] official implementation for "Input Perturbation Reduces Exposure Bias in Diffusion Models"
MIT License
96 stars, 9 forks

Input perturbation by increasing the noise strength #27

Open pengzhangzhi opened 5 months ago

pengzhangzhi commented 5 months ago

Hi @forever208, great work! I like your observation of the inconsistency between training and sampling, and the simple input perturbation you propose to mitigate it. I have a question about your implementation code:

new_noise = noise + gamma * th.randn_like(noise)  # gamma=0.1 

I understand this equation as increasing the noise strength, because the new noise is essentially two Gaussian noises added with weights 1 and gamma. Thus the resulting new_noise is Gaussian noise scaled by (1+gamma).

If so, can I understand the input perturbation as interpolating with a larger noise?

Thanks, Zhangzhi

forever208 commented 5 months ago

Hi @pengzhangzhi, thanks for your comments and insightful observation. Yes, you can understand input perturbation as interpolating with a larger noise, but the training target in this case is a scaled epsilon.

We actually derived an equivalent loss for input perturbation (IP), shown below.

[image: 2024-04-19_20-30]

Treating $\bar{\beta_t}\sqrt{1+\gamma^2}$ as a whole, this term becomes the new noise schedule, and it is larger than the original noise schedule $\bar{\beta_t}$.

From this, we can interpret IP as training the reverse diffusion with a larger noise schedule $\bar{\beta_t}\sqrt{1+\gamma^2}$, but the target is $\frac{\epsilon}{\sqrt{1+\gamma^2}}$, which differs from $\epsilon$.
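For readers without access to the image, here is a sketch of how such an equivalence can be derived. It is not a reproduction of the equation in the screenshot. It assumes standard DDPM notation, and it assumes that $\bar{\beta_t}$ in the reply above plays the role of the noise coefficient $\sqrt{1-\bar{\alpha}_t}$; the symbols $y_t$, $\xi$ and $\epsilon'$ are introduced here for illustration. Note that the sum of two independent Gaussians with weights 1 and $\gamma$ has standard deviation $\sqrt{1+\gamma^2}$, not $1+\gamma$.

$$
\begin{aligned}
y_t &= \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,(\epsilon + \gamma\xi), \qquad \epsilon,\ \xi \sim \mathcal{N}(0, I) \\
\epsilon + \gamma\xi &= \sqrt{1+\gamma^2}\,\epsilon', \qquad \epsilon' \sim \mathcal{N}(0, I) \\
\mathbb{E}[\epsilon \mid \epsilon'] &= \frac{\epsilon'}{\sqrt{1+\gamma^2}} \\
\mathbb{E}\,\big\|\epsilon - \epsilon_\theta(y_t, t)\big\|^2 &= \mathbb{E}\,\Big\|\frac{\epsilon'}{\sqrt{1+\gamma^2}} - \epsilon_\theta(y_t, t)\Big\|^2 + \text{const}
\end{aligned}
$$

The constant comes from the part of $\epsilon$ that is independent of $\epsilon'$ (variance $\gamma^2/(1+\gamma^2)$ per dimension), so under these assumptions both objectives have the same minimiser.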

pengzhangzhi commented 5 months ago

Thanks!! I like this idea, super neat! Is there anywhere in the code where I can read the loss? I am interested in applying this technique to other diffusion models, so I want to make sure that I follow all the details, at least on the implementation side. Thanks so much for the nice response :)

forever208 commented 5 months ago

@pengzhangzhi do you mean the equivalent loss I showed above? It is not implemented in the released code, but I personally tried it, and the results are equivalent to input perturbation.
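Since the equivalent loss is not in the released code, the following is only a minimal scalar sketch of the two formulations, not the author's implementation. It assumes the standard DDPM forward process with noise coefficient `sqrt(1 - alpha_bar)`; the names `ip_sample`, `equivalent_sample`, `mean_squared_error` and `toy_predictor` are hypothetical and exist only for this illustration. In a tensor implementation the `rng.gauss` calls would presumably become `th.randn_like`.

```python
import math
import random


def ip_sample(x0, alpha_bar, gamma, rng):
    """Input perturbation: perturb the network input, keep eps as the target."""
    eps = rng.gauss(0.0, 1.0)
    xi = rng.gauss(0.0, 1.0)
    # Mirrors `new_noise = noise + gamma * th.randn_like(noise)` from the issue.
    new_noise = eps + gamma * xi
    y_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * new_noise
    return y_t, eps


def equivalent_sample(x0, alpha_bar, gamma, rng):
    """Assumed equivalent form: larger noise coefficient, scaled target."""
    scale = math.sqrt(1.0 + gamma ** 2)
    eps = rng.gauss(0.0, 1.0)
    y_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * scale * eps
    return y_t, eps / scale


def mean_squared_error(sampler, predictor, n=200_000, seed=0, **kwargs):
    """Monte Carlo estimate of E[(target - predictor(y_t))^2]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        y_t, target = sampler(rng=rng, **kwargs)
        total += (target - predictor(y_t)) ** 2
    return total / n


def toy_predictor(y_t):
    # Hypothetical stand-in for the network eps_theta(y_t, t).
    return 0.3 * y_t
```

For any fixed predictor, the two estimates should differ only by the constant `gamma**2 / (1 + gamma**2)`, which does not depend on the predictor. That is one way to check the claim that both objectives lead to the same optimum:

```python
cfg = dict(x0=1.0, alpha_bar=0.5, gamma=0.5)
gap = (mean_squared_error(ip_sample, toy_predictor, seed=1, **cfg)
       - mean_squared_error(equivalent_sample, toy_predictor, seed=2, **cfg))
print(gap, cfg["gamma"] ** 2 / (1 + cfg["gamma"] ** 2))
```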