gmongaras opened this issue 1 year ago
I think they follow the original DDPM implementation. By predicting x_0 first, they can apply value clipping to x_0. This is a trick, not mentioned in the paper, that improves sampling quality. From the clipped x_0 we can then do posterior sampling to get x_{t-1}.
Without value clipping, the two are the same (either go directly from x_t to x_{t-1} using Eq. 13, or predict x_0 first and then sample x_{t-1}). See the discussion in https://github.com/hojonathanho/diffusion/issues/5
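To make the equivalence concrete, here's a toy NumPy sketch (the schedule values and all variable names are illustrative, not taken from the repo): both routes produce the same mean as long as no clipping is applied to x_0.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy noise schedule (illustrative, not the repo's actual schedule).
T = 10
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_bar = np.cumprod(alphas)
alphas_bar_prev = np.append(1.0, alphas_bar[:-1])

t = 5
x_t = rng.standard_normal(4)
eps = rng.standard_normal(4)  # stand-in for the network's epsilon prediction

# Path A: mean of p(x_{t-1} | x_t) directly from epsilon (the Eq. 13 form).
mu_direct = (x_t - betas[t] / np.sqrt(1.0 - alphas_bar[t]) * eps) / np.sqrt(alphas[t])

# Path B: first recover x_0 from epsilon, then plug it into the
# posterior mean mu_tilde(x_t, x_0) (the Eq. 11 form).
x0_pred = (x_t - np.sqrt(1.0 - alphas_bar[t]) * eps) / np.sqrt(alphas_bar[t])
# x0_pred = np.clip(x0_pred, -1.0, 1.0)  # <- the value-clipping trick would go here
coef_x0 = np.sqrt(alphas_bar_prev[t]) * betas[t] / (1.0 - alphas_bar[t])
coef_xt = np.sqrt(alphas[t]) * (1.0 - alphas_bar_prev[t]) / (1.0 - alphas_bar[t])
mu_via_x0 = coef_x0 * x0_pred + coef_xt * x_t

assert np.allclose(mu_direct, mu_via_x0)  # identical when no clipping is applied
```

With the clip line uncommented, the two means diverge whenever the predicted x_0 leaves [-1, 1], which is exactly where the trick changes the sample.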
Oh yeah, that makes sense! As I've learned more about diffusion models, it looks like predicting x_0 produces better results as one can skip steps like in DDIM.
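In case it helps, here's a rough sketch of why having an x_0 prediction enables step skipping: the deterministic DDIM update (eta = 0) renoises the predicted x_0 to any earlier timestep s, not just t-1. The schedule and names below are illustrative, not the repo's.

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy noise schedule (illustrative values).
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas_bar = np.cumprod(1.0 - betas)

t, s = 800, 600  # jump straight from step 800 to step 600
x_t = rng.standard_normal(4)
eps = rng.standard_normal(4)  # stand-in for the network's epsilon prediction

# Recover x_0 from the epsilon prediction, same as in the ancestral sampler.
x0_pred = (x_t - np.sqrt(1.0 - alphas_bar[t]) * eps) / np.sqrt(alphas_bar[t])

# Deterministic DDIM update (eta = 0): renoise x0_pred to an arbitrary
# earlier step s, which is what makes step skipping possible.
x_s = np.sqrt(alphas_bar[s]) * x0_pred + np.sqrt(1.0 - alphas_bar[s]) * eps
```

Setting s = t just reproduces x_t, and s can be any earlier step, so the stride is a free choice at sampling time.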
I was looking through the code to see how the paper was implemented, but I ran into an issue with the part of the paper that computes the KL divergence between two Gaussians:
Specifically, the loss at time t-1 is the KL divergence between the predicted Gaussian and the real (posterior) Gaussian at time t-1. The predicted Gaussian is defined as follows:

p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), Σ_θ(x_t, t))

And the real Gaussian is defined as follows:

q(x_{t-1} | x_t, x_0) = N(x_{t-1}; μ̃_t(x_t, x_0), β̃_t I)
The formulation of the loss function makes sense to me, but when I look at the code, it seems the authors have the model predict μ̃ (Eq. 11) rather than μ (Eq. 13). I'm looking at the following function in the code: https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L232
In this function, the mean is computed from epsilon by first predicting x_0, then computing the posterior mean from that prediction.
To predict x_0, the following function is used: https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L328
But this function looks like the formulation of the mean function (Eq. 13).
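For reference, my current reading is that this function is the forward process q(x_t | x_0) = N(√(ᾱ_t) x_0, (1 − ᾱ_t) I) solved for x_0, rather than Eq. 13 itself. A toy NumPy sketch of that inversion (illustrative schedule and names, not the repo's):

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy noise schedule (illustrative values).
T = 10
betas = np.linspace(1e-4, 0.02, T)
alphas_bar = np.cumprod(1.0 - betas)

t = 7
x_0 = rng.standard_normal(4)
eps = rng.standard_normal(4)

# Forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.
x_t = np.sqrt(alphas_bar[t]) * x_0 + np.sqrt(1.0 - alphas_bar[t]) * eps

# The same equation solved for x_0 (the shape of _predict_xstart_from_eps):
x0_rec = np.sqrt(1.0 / alphas_bar[t]) * x_t - np.sqrt(1.0 / alphas_bar[t] - 1.0) * eps

assert np.allclose(x0_rec, x_0)  # exact inversion when eps is the true noise
```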
I have a couple of questions regarding the implementation:
Thanks for the help!