NVlabs / I2SB

Why do you sample this way? #1

Closed (jan-pair closed this issue 1 year ago)

jan-pair commented 1 year ago

Hi, thanks for the awesome work!

The paper mentions that:

(ii) during generation when only X1∼pB is given, running standard DDPM starting from X1 induces the same marginal density of SB paths so long as the predicted X0 is close to X0.

Algorithm 2 also states that the previous state is sampled "according to DDPM", but I haven't found details on the parameters of that DDPM.

As I understand the code, deterministic sampling is done as follows:

pred_x0 = xt - std_fwd * net(xt, step)
mu_x0 = std_fwd_prev ** 2 / std_fwd ** 2
xt = mu_x0 * pred_x0 + (1 - mu_x0) * xt

which can be re-written as:

xt = xt - std_fwd_prev ** 2 / std_fwd * net(xt, step)

but, as far as I know, this doesn't correspond to any known DDPM sampling rule, whether VP or VE.

Am I missing something? Can you please clarify?

Thanks in advance!

jan-pair commented 1 year ago

Apologies, it looks like I misread the computation inside compute_gaussian_product_coef:

compute_gaussian_product_coef(a, b) returns b^2 / (a^2 + b^2) and a^2 / (a^2 + b^2) as its first two outputs, not a^2 / (a^2 + b^2) and b^2 / (a^2 + b^2).

As a result, the deterministic sampler is given by:

xt = xt - std_fwd * (1 - std_fwd_prev ** 2 / std_fwd ** 2) * net(xt, step)
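For anyone who wants to check the algebra numerically, here is a minimal sketch. `gaussian_product_coef` is a stand-in written from the description in this thread, not the repository's function, and I am assuming it is called with a = std_fwd_prev and b = sqrt(std_fwd^2 - std_fwd_prev^2), since that choice reproduces the formula above. `eps` stands for the output of net(xt, step).

```python
def gaussian_product_coef(a, b):
    # Stand-in for the helper as described in this thread: the first two
    # outputs are b^2 / (a^2 + b^2) and a^2 / (a^2 + b^2).
    denom = a ** 2 + b ** 2
    return b ** 2 / denom, a ** 2 / denom


def step_two_stage(xt, eps, std_fwd, std_fwd_prev):
    # Predict x0, then mix it with xt using the product coefficients.
    pred_x0 = xt - std_fwd * eps
    # Assumption: the helper receives std_fwd_prev and the std of the
    # increment between the two noise levels.
    std_delta = (std_fwd ** 2 - std_fwd_prev ** 2) ** 0.5
    mu_x0, mu_xn = gaussian_product_coef(std_fwd_prev, std_delta)
    return mu_x0 * pred_x0 + mu_xn * xt


def step_closed_form(xt, eps, std_fwd, std_fwd_prev):
    # The one-line form of the deterministic update.
    return xt - std_fwd * (1 - std_fwd_prev ** 2 / std_fwd ** 2) * eps


# Arbitrary scalar inputs; both forms should agree up to rounding.
a = step_two_stage(0.7, 0.3, std_fwd=0.375, std_fwd_prev=0.2)
b = step_closed_form(0.7, 0.3, std_fwd=0.375, std_fwd_prev=0.2)
print(abs(a - b) < 1e-12)
```

With the argument order swapped in the helper's outputs (my original misreading), the two forms no longer agree.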

This is also described in the appendix, in the proof of Proposition 3.3 (the formula for p(x_n | x_0, x_{n+1})). For cross-reference, a similar formula appears in Appendix F of "Score-Based Generative Modeling through Stochastic Differential Equations" by Song et al., where they expand q(x_{i-1} | x_i, x_0).
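To spell out the step I was missing, here is the posterior written in the code's terms. This is my own sketch, not the paper's notation: I assume a VE-style construction x_n = x_0 + σ_n z and x_{n+1} = x_n + δ_n z' with independent standard Gaussian z, z', and I identify σ_n with std_fwd_prev, σ_{n+1} with std_fwd, and ε with net(xt, step).

```latex
\begin{aligned}
\delta_n^2 &:= \sigma_{n+1}^2 - \sigma_n^2, \\
p(x_n \mid x_0, x_{n+1}) &= \mathcal{N}\!\left(x_n;\ 
  \frac{\delta_n^2}{\sigma_{n+1}^2}\, x_0 + \frac{\sigma_n^2}{\sigma_{n+1}^2}\, x_{n+1},\ 
  \frac{\sigma_n^2\, \delta_n^2}{\sigma_{n+1}^2}\, I \right), \\
\hat{x}_0 &= x_{n+1} - \sigma_{n+1}\, \epsilon, \\
\mathbb{E}[x_n \mid \hat{x}_0, x_{n+1}]
  &= \frac{\delta_n^2}{\sigma_{n+1}^2}\left(x_{n+1} - \sigma_{n+1}\, \epsilon\right)
   + \frac{\sigma_n^2}{\sigma_{n+1}^2}\, x_{n+1} \\
  &= x_{n+1} - \sigma_{n+1}\left(1 - \frac{\sigma_n^2}{\sigma_{n+1}^2}\right)\epsilon .
\end{aligned}
```

The deterministic sampler keeps only this mean; a stochastic DDPM-style step would additionally add noise with the variance shown in the second line.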

In general, one can use any sampler designed for a variance-exploding (VE) process by setting its sigmas to std_fwd. Compared with the unconditional setting, the key difference seems to be that the maximum sigma is much smaller (roughly 0.375 versus 200+ on ImageNet). Arguably, the maximum sigma is the main hyperparameter to tune for this method to work properly on a specific task.
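As a toy illustration of plugging a sigma schedule into the deterministic update, here is a scalar sketch. Everything in it is hypothetical: `sample_deterministic`, the linear schedule, and `oracle_net` (which cheats by knowing the clean value) are made up for the example, and `net` here takes the noise level rather than a step index. It is not the repository's sampler.

```python
def sample_deterministic(x1, sigmas, net):
    # sigmas: decreasing std_fwd values, from the maximum sigma down to 0.
    xt = x1
    for std_fwd, std_fwd_prev in zip(sigmas[:-1], sigmas[1:]):
        eps = net(xt, std_fwd)
        # Posterior-mean update with sigma = std_fwd at each step.
        xt = xt - std_fwd * (1 - std_fwd_prev ** 2 / std_fwd ** 2) * eps
    return xt


X0_TRUE = 1.5  # toy "clean" value, known only to the oracle below


def oracle_net(xt, std_fwd):
    # Hypothetical perfect predictor: returns (xt - x0) / std_fwd,
    # so that xt - std_fwd * net(...) recovers x0 exactly.
    return (xt - X0_TRUE) / std_fwd


# Linear schedule from 0.375 down to 0 in 10 steps (an arbitrary choice).
sigmas = [0.375 * (1 - i / 10) for i in range(11)]
x1 = X0_TRUE + 0.375 * 0.8  # toy starting point at the maximum sigma
result = sample_deterministic(x1, sigmas, oracle_net)
print(abs(result - X0_TRUE) < 1e-9)
```

With a perfect predictor the loop lands on the clean value once the schedule reaches sigma = 0; with a learned network, how well this works would depend on the schedule and especially on the maximum sigma.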

Closing this for now as solved, thank you!