NVlabs / edm

Elucidating the Design Space of Diffusion-Based Generative Models (EDM)

Deterministic sampling and sampling steps #4

Open FutureXiang opened 1 year ago

FutureXiang commented 1 year ago

Hi, I try to train the EDM model with a simpler 35.7M #params UNet (proposed by original DDPM paper) and compare the result with DDPM/DDIM. I notice that $S_{churn} = 0$ leads to deterministic sampling, and $\gammai = \sqrt{2}-1$ leads to "max" stochastic sampling. So I introduce a parameter $\eta = \frac{S{churn} / N}{\sqrt{2}-1}$ to control stochasticity by interpolations. That is to say, $\gamma_i = (\sqrt{2}-1) * \eta$. Like in DDIM, $\eta = 0$ means deterministic, $\eta = 1$ means "max" stochastic.

I set different $\eta$ values and numbers of sampling steps and measured the FIDs:

| $\eta$ / steps | steps=18 | steps=50 | steps=100 |
| --- | --- | --- | --- |
| $\eta=0.0$ | 3.39 | 3.64 | 3.68 |
| $\eta=0.5$ | 3.10 | 2.95 | 2.93 |
| $\eta=1.0$ | 3.12 | 2.84 | 2.97 |
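For context on what $\eta$ changes inside each sampler step, this is a rough paraphrase of the "churn" part of one `edm_sampler` iteration in `generate.py` (variable names and the NumPy form are mine; the repo works on PyTorch tensors):

```python
import numpy as np

def churn_step(x_cur, t_cur, eta, S_noise=1.0):
    # gamma_i = eta * (sqrt(2) - 1): eta = 0 keeps the step purely deterministic (ODE/Heun),
    # eta = 1 applies the maximum churn allowed by the paper's clamp.
    gamma = eta * (np.sqrt(2) - 1)
    # Temporarily raise the noise level from t_cur to t_hat by injecting fresh noise;
    # the deterministic step toward t_next then proceeds from (x_hat, t_hat).
    t_hat = t_cur + gamma * t_cur
    x_hat = x_cur + np.sqrt(t_hat**2 - t_cur**2) * S_noise * np.random.randn(*x_cur.shape)
    return x_hat, t_hat
```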

The FID is supposed to decrease when using more sampling steps, right? But why does the FID get worse for deterministic sampling? It behaves normally at $\eta=0.5$, yet at $\eta=1.0$ it increases again from 50 steps to 100 steps. Why is the behavior so unstable and unpredictable?

To confirm it's not a bug, I also trained a model with your official codebase under a simpler setting close to DDPM (duration=100, augment=None, xflip=True; channel_mult=[1,2,2,2], num_blocks=2; see the configuration sketch after the table). The results are:

| $\eta$ / steps | steps=18 | steps=50 |
| --- | --- | --- |
| $\eta=0.0$ | 2.94 | 3.09 |
| $\eta=0.5$ | 2.80 | 2.75 |
| $\eta=1.0$ | 2.95 | 2.78 |
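For completeness, this is roughly how I understand the smaller, DDPM-like network in terms of the repo's `SongUNet` class (a sketch of my assumed mapping onto its constructor; the exact kwargs used by the official training configs may differ):

```python
from training.networks import SongUNet  # NVlabs/edm codebase

# Sketch (my assumption, not copied from the official configs) of the simpler UNet:
net = SongUNet(
    img_resolution=32,           # CIFAR-10
    in_channels=3,
    out_channels=3,
    model_channels=128,
    channel_mult=[1, 2, 2, 2],   # same multipliers as the original DDPM UNet
    num_blocks=2,                # 2 residual blocks per resolution
)
print(sum(p.numel() for p in net.parameters()) / 1e6, 'M parameters')
```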

For deterministic sampling, the FID still gets worse with more steps. When $\eta > 0$, the FID improves slightly as the number of steps increases. If hyper-parameter settings and the corresponding performance are not consistently predictable, how can one obtain a good model on a different dataset? Only by brute force and grid search?

Could you please provide some explanation and thoughts? Thanks a lot!