microsoft / VQ-Diffusion

Official implementation of VQ-Diffusion

Calculation of q_posterior? #20

Open PanXiebit opened 1 year ago

PanXiebit commented 1 year ago

https://github.com/microsoft/VQ-Diffusion/blob/16dc744405e59ed1833513ebb1db87d6263d38be/image_synthesis/modeling/transformers/diffusion_transformer.py#L244 `q_pred(log_x_start, t)` is the forward computation that samples x_t given x_0 and t, but in this line of the posterior, the condition passed in is `log_x_t` and `t`?

https://github.com/microsoft/VQ-Diffusion/blob/16dc744405e59ed1833513ebb1db87d6263d38be/image_synthesis/modeling/transformers/diffusion_transformer.py#L253

The same confusion appears in this line as well: `q_pred_one_timestep(self, log_x_t, t)` is the forward computation that samples x_t given x_{t-1}, but in this line of the posterior, the condition passed in is again `log_x_t` and `t`?
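To keep the two functions straight, here is a toy NumPy sketch (my own code with made-up parameters, not the repo's implementation) of the two forward computations for a K-category chain without the [mask] token, assuming the uniform transition matrix from the paper:

```python
import numpy as np

K = 4
alphas = [0.9, 0.8, 0.7]  # hypothetical per-step "keep" probabilities

def Q(t):
    """Single-step transition matrix Q_t; row i is q(x_t | x_{t-1} = i)."""
    a = alphas[t]
    return a * np.eye(K) + (1 - a) / K * np.ones((K, K))

def Q_bar(t):
    """Cumulative matrix Q_1 Q_2 ... Q_t; row i is q(x_t | x_0 = i)."""
    M = np.eye(K)
    for s in range(t + 1):
        M = M @ Q(s)
    return M

x0, x_prev, t = 2, 2, 2
q_xt_given_x0 = Q_bar(t)[x0]       # what q_pred(log_x_start, t) computes
q_xt_given_xprev = Q(t)[x_prev]    # what q_pred_one_timestep(log_x, t) computes
assert np.isclose(q_xt_given_x0.sum(), 1.0)
assert np.isclose(q_xt_given_xprev.sum(), 1.0)
```

So both functions are row lookups into a transition matrix; the question in this issue is why the posterior code passes `log_x_t` (the *later* variable) into them.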

trestad commented 1 year ago

same confusion

yingyukexiansheng commented 1 year ago

same confusion

trestad commented 1 year ago

Hello everyone. I contacted the author Shuyang Gu and got a patient reply. I appreciate Gu's help! The following answer mainly comes from the author's reply, with some personal understanding added:

Here is the explanation of lines 244-250. In the transition matrix Q_t, the i-th row holds the probabilities of reaching each possible x_t given that x_0 is the i-th token; the i-th column holds the probabilities of reaching the current x_t (the i-th token) from each possible x_0. When calculating the posterior, the goal is the probability of each x_0 transitioning to the known x_t (`log_x_t`), so we need the columns corresponding to the elements of `log_x_t`. Gathering the columns directly would be a little complicated, so the author uses a trick:

1. Obtain Q_t via `q_pred(…, t)`.
2. Gather the rows indexed by the items of `log_x_t`, i.e. `q_pred(log_x_t, t)`.
3. When [mask] is not considered, Q_t is symmetric, so the gathered rows are identical to the columns that give the probabilities of reaching the current x_t from each x_0.
4. When [mask] is considered, the last value in each gathered row must be replaced with `ct_cumprod` (this is clear from Eq. 7 in the paper).
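Steps 2-3 can be illustrated numerically. This is my own toy sketch (not the repo's code), showing that for a symmetric cumulative matrix without [mask], the row gathered for the observed x_t equals the column the posterior actually needs:

```python
import numpy as np

K = 4
a = 0.7  # hypothetical cumulative "keep" probability
Q_bar_t = a * np.eye(K) + (1 - a) / K * np.ones((K, K))  # symmetric, no [mask]

x_t = 1
row = Q_bar_t[x_t]      # what gathering rows via q_pred(log_x_t, t) returns
col = Q_bar_t[:, x_t]   # q(x_t | x_0) for every x_0 -- what the posterior needs
assert np.allclose(row, col)  # holds only because Q_bar_t is symmetric

# With a [mask] token appended, the matrix is no longer symmetric (the mask
# row is a one-hot on mask), so the value gathered in the mask slot is wrong
# and must be overwritten with the cumulative mask probability ct_cumprod.
```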

Lines 253-258 are for the same reason.

Here is the explanation of lines 260-267. The goal of the reparameterization is to compute $p(x_{t-1}|x_t) = \sum_{x_0} q(x_{t-1}|x_t, x_0)\, p(x_0|x_t)$ (Eq. 11 in the paper), where $q(x_{t-1}|x_t, x_0) = q(x_t|x_{t-1}, x_0)\, q(x_{t-1}|x_0)\, /\, q(x_t|x_0)$ (*).

In (*), by the Markov property, $q(x_t|x_{t-1}, x_0) = q(x_t|x_{t-1})$, which is `log_qt_one_timestep`; and the remaining factor $\sum_{x_0} p(x_0|x_t)\, q(x_{t-1}|x_0)\, /\, q(x_t|x_0)$ is computed as `q_pred(q, t-1)`, where $q$ stands for $p(x_0|x_t)/q(x_t|x_0)$. Because $p(x_0|x_t)/q(x_t|x_0)$ does not sum to 1, lines 262-265 normalize it into $q$ first, and the final result is renormalized in line 266.
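The factorization can be checked numerically. Below is a toy sketch (my own code with made-up transition parameters, not the repo's), verifying that the reparameterized form matches a direct Bayes-rule evaluation of Eq. 11:

```python
import numpy as np

rng = np.random.default_rng(0)
K = 4
a1, a2 = 0.9, 0.8                      # hypothetical "keep" probabilities
Q2 = a2 * np.eye(K) + (1 - a2) / K     # Q_t (uniform transitions, no [mask])
Qbar1 = a1 * np.eye(K) + (1 - a1) / K  # cumulative matrix for step t-1
Qbar2 = Qbar1 @ Q2                     # cumulative matrix for step t

x_t = 3
p_x0_given_xt = rng.dirichlet(np.ones(K))  # stand-in for the network's p(x0|xt)

# Direct evaluation: sum_x0 q(x_{t-1}|x_t, x0) p(x0|x_t), expanding the
# posterior q(x_{t-1}|x_t, x0) with Bayes' rule.
direct = np.zeros(K)
for x0 in range(K):
    q_post = Q2[:, x_t] * Qbar1[x0] / Qbar2[x0, x_t]
    direct += q_post * p_x0_given_xt[x0]

# Reparameterized form: q = normalize(p(x0|xt) / q(x_t|x0)) (cf. lines
# 262-265), then q(x_t|x_{t-1}) * q_pred(q, t-1), renormalized at the end
# (cf. line 266).
q = p_x0_given_xt / Qbar2[:, x_t]
q = q / q.sum()
reparam = Q2[:, x_t] * (q @ Qbar1)
reparam = reparam / reparam.sum()

assert np.allclose(direct, reparam)
```

The two normalizations only change the result by a constant factor, which is why the final renormalization in line 266 recovers the correct distribution.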

yingyukexiansheng commented 1 year ago

Thanks so much for your patience in sharing. I tried to follow the explanation above, but I am still confused about the overall `q_posterior` process in the code. Why can `q_pred(log_x_t, t)` recover $Q_t$ — isn't $Q_t$ already known? And is $q\_pred(log\_x\_t, t) = \log(Q_t^{line} \cdot v(x_t))$ right?

Frreed commented 1 year ago

@trestad Hi, thanks for your explanation. But I am still confused about $p(x_0|x_t) \cdot q(x_{t-1}|x_0)\, /\, q(x_t|x_0) =$ `q_pred(q, t-1)`: do you know why `q` is used as the input rather than `x0`?

Ed-ivan commented 10 months ago

I have a question as well. Since we have to replace the last value in the rows with `ct_cumprod`, but the mask (batch, 1, 256) is applied to `log_qt`, wouldn't that change all the rows? Looking forward to your reply!

mounchiliu commented 1 week ago

I suppose the distribution is q = normalize(p(x0|xt) / q(xt|x0)), and q(x_{t-1}|x0) corresponds to `q_pred(..., t-1)` because it applies the hyper-parameters of $\bar{Q}_{t-1}$; then p(x0|xt) · q(x_{t-1}|x0) / q(xt|x0) can be computed from `q_pred` and the distribution q, which is `q_pred(q, t-1)`.