thu-ml / DPM-Solver-v3

Official code for "DPM-Solver-v3: Improved Diffusion ODE Solver with Empirical Model Statistics" (NeurIPS 2023)
MIT License

DPM-Solver-v3

This repo is the official code for the paper DPM-Solver-v3: Improved Diffusion ODE Solver with Empirical Model Statistics (NeurIPS 2023).

Project Page | Paper | arXiv

DPM-Solver-v3 is a training-free ODE solver dedicated to fast sampling of diffusion models. It uses precomputed empirical model statistics (EMS) to speed up convergence by up to 40%, and it brings especially notable quality improvements in few-step sampling (5 to 10 steps).

Please refer to the paper and project page for detailed methods and results.

Code Examples

We integrate DPM-Solver-v3 into various codebases, and the previous state-of-the-art samplers DPM-Solver++ and UniPC are also included, enabling convenient benchmarking and comparisons.

The code examples are in codebases/. They reproduce the experimental results reported in the paper.

| Name | Original Repository | Pretrained Models | Dataset | Type |
| --- | --- | --- | --- | --- |
| score_sde | https://github.com/yang-song/score_sde_pytorch | cifar10_ddpmpp_deep_continuous-checkpoint_8 | CIFAR-10 | Uncond/Pixel-Space |
| edm | https://github.com/NVlabs/edm | edm-cifar10-32x32-uncond-vp | CIFAR-10 | Uncond/Pixel-Space |
| guided-diffusion | https://github.com/openai/guided-diffusion | 256x256_diffusion/256x256_classifier | ImageNet-256 | Cond/Pixel-Space |
| stable-diffusion | https://github.com/CompVis/latent-diffusion | lsun_beds256-model | LSUN-Bedroom | Uncond/Latent-Space |
| stable-diffusion | https://github.com/CompVis/stable-diffusion | sd-v1-4 | MS-COCO2014 | Cond/Latent-Space |

Documentation

We provide the PyTorch implementation of DPM-Solver-v3 in a single file dpm_solver_v3.py. We suggest referring to the code examples for its practical usage in different settings.

To use DPM-Solver-v3, one can follow the steps below. Special thanks to DPM-Solver for their unified model wrapper to support various diffusion models.

1. Define Noise Schedule

The noise schedule $\alpha_t,\sigma_t$ defines the forward transition kernel from time $0$ to time $t$:

$$ p(x_t|x_0)=\mathcal N(x_t;\alpha_tx_0,\sigma_t^2I) $$

or equivalently

$$ x_t=\alpha_tx_0+\sigma_t\epsilon,\quad \epsilon\sim\mathcal N(0,I) $$

We support two main classes of noise schedules:

| Name | Python Class | Definition | Type |
| --- | --- | --- | --- |
| Variance Preserving (VP) | NoiseScheduleVP | $\alpha_t^2+\sigma_t^2=1$ | discrete/continuous |
| EDM (https://github.com/NVlabs/edm) | NoiseScheduleEDM | $\alpha_t=1,\sigma_t=t$ | continuous |
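
As a quick numerical illustration of the two schedule families above, the following pure-Python sketch draws $x_t$ from the forward kernel for a scalar $x_0$. The helper names (vp_alpha_sigma, edm_alpha_sigma, forward_sample) are hypothetical and are not part of dpm_solver_v3.py; the VP case takes $\alpha_t$ as given instead of deriving it from a particular schedule.

```python
import math
import random


def vp_alpha_sigma(alpha_t):
    """VP: sigma_t follows from alpha_t through alpha_t^2 + sigma_t^2 = 1."""
    return alpha_t, math.sqrt(1.0 - alpha_t ** 2)


def edm_alpha_sigma(t):
    """EDM: alpha_t = 1 and sigma_t = t."""
    return 1.0, t


def forward_sample(x0, alpha_t, sigma_t, rng):
    """Draw x_t = alpha_t * x_0 + sigma_t * eps with eps ~ N(0, 1)."""
    eps = rng.gauss(0.0, 1.0)
    return alpha_t * x0 + sigma_t * eps, eps


rng = random.Random(0)
alpha, sigma = vp_alpha_sigma(0.6)
x_t, eps = forward_sample(2.0, alpha, sigma, rng)
print(alpha ** 2 + sigma ** 2)   # VP constraint, approximately 1
print(edm_alpha_sigma(3.0))      # alpha stays 1, sigma equals t
```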

1.1. Discrete-time DPMs

VP

We support a piecewise linear interpolation of $\log\alpha_{t}$ in the NoiseScheduleVP class to convert discrete noise schedules to continuous noise schedules.

We need either the $\beta_i$ array or the $\bar{\alpha}_i$ array (see DDPM for details) to define the noise schedule. The detailed relationship is: $$ \bar{\alpha}_i = \prod_{k=1}^{i} (1 - \beta_k) $$

$$ \alpha_{t_i} = \sqrt{\bar{\alpha}_i} $$

Define the discrete-time noise schedule by the $\beta_i$ array:

noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)

Or define the discrete-time noise schedule by the $\bar{\alpha}_i$ array:

noise_schedule = NoiseScheduleVP(schedule='discrete', alphas_cumprod=alphas_cumprod)
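
The relationship between the two arrays, and the idea of interpolating $\log\alpha_t$ between discrete steps, can be sketched in plain Python. This is only an illustration: the helper names are hypothetical, the $\beta_i$ values are the commonly used DDPM linear schedule chosen as an example, and placing step $i$ at time $(i+1)/N$ is an assumption that may differ from what NoiseScheduleVP does internally.

```python
import math


def alphas_cumprod_from_betas(betas):
    """Return bar_alpha_i = prod_{k<=i} (1 - beta_k) for every i."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def interp_log_alpha(t, t_grid, log_alpha_grid):
    """Piecewise linear interpolation of log(alpha_t) on an increasing time grid."""
    if t <= t_grid[0]:
        return log_alpha_grid[0]
    for i in range(1, len(t_grid)):
        if t <= t_grid[i]:
            w = (t - t_grid[i - 1]) / (t_grid[i] - t_grid[i - 1])
            return (1.0 - w) * log_alpha_grid[i - 1] + w * log_alpha_grid[i]
    return log_alpha_grid[-1]


# Assumed example values: a linear beta schedule with N = 1000 steps.
N = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (N - 1) for i in range(N)]
alphas_cumprod = alphas_cumprod_from_betas(betas)

# alpha_{t_i} = sqrt(bar_alpha_i), so log(alpha_{t_i}) = 0.5 * log(bar_alpha_i).
log_alphas = [0.5 * math.log(a) for a in alphas_cumprod]
# Assumption: the i-th discrete step (0-based) sits at continuous time (i + 1) / N.
t_grid = [(i + 1) / N for i in range(N)]

t_mid = 0.5 * (t_grid[9] + t_grid[10])
print(math.exp(interp_log_alpha(t_mid, t_grid, log_alphas)))
```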

1.2. Continuous-time DPMs

VP

We support both the linear schedule and the cosine schedule for continuous-time DPMs.

| Name | $\alpha_t$ | Example Paper |
| --- | --- | --- |
| VP (linear) | $e^{-\frac{1}{4}(\beta_1-\beta_0)t^2-\frac{1}{2}\beta_0t}$ | DDPM, ScoreSDE |
| VP (cosine) | $\frac{f(t)}{f(0)}$ ($f(t)=\cos\left(\frac{t+s}{1+s}\frac{\pi}{2}\right)$) | improved-DDPM |

Define the continuous-time linear noise schedule with $\beta_0=0.1,\beta_1=20$:

noise_schedule = NoiseScheduleVP(schedule='linear', continuous_beta_0=0.1, continuous_beta_1=20.)

Define the continuous-time cosine noise schedule with $s=0.008$:

noise_schedule = NoiseScheduleVP(schedule='cosine')

EDM

noise_schedule = NoiseScheduleEDM()
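
For reference, the two continuous-time VP formulas from the table in this section can be evaluated directly. The sketch below is a plain-Python transcription of those formulas with the parameter values quoted in the text ($\beta_0=0.1$, $\beta_1=20$, $s=0.008$). The function names are hypothetical, and the code is not taken from NoiseScheduleVP.

```python
import math


def vp_linear_log_alpha(t, beta_0=0.1, beta_1=20.0):
    """log(alpha_t) = -(beta_1 - beta_0) * t^2 / 4 - beta_0 * t / 2."""
    return -0.25 * (beta_1 - beta_0) * t ** 2 - 0.5 * beta_0 * t


def vp_cosine_log_alpha(t, s=0.008):
    """log(alpha_t) for alpha_t = f(t) / f(0), f(t) = cos((t + s) / (1 + s) * pi / 2)."""
    def f(u):
        return math.cos((u + s) / (1.0 + s) * math.pi / 2.0)
    return math.log(f(t)) - math.log(f(0.0))


def vp_sigma(log_alpha):
    """VP constraint: sigma_t = sqrt(1 - alpha_t^2)."""
    return math.sqrt(1.0 - math.exp(2.0 * log_alpha))


for t in (0.1, 0.5, 0.9):
    la_lin, la_cos = vp_linear_log_alpha(t), vp_cosine_log_alpha(t)
    print(t, math.exp(la_lin), vp_sigma(la_lin), math.exp(la_cos), vp_sigma(la_cos))
```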

2. Define Model Wrapper

A given diffusion model takes a time label as input, which may be discrete (e.g. 0 to 999) or continuous (e.g. 0 to 1), and its output type may be "noise", "x_start", "v" or "score" (see Model Types). We wrap the model function into the following format:

model_fn(x, t_continuous) -> noise

where t_continuous is the continuous time label (i.e. 0 to 1), and the output type of the wrapped function is "noise", i.e. a noise prediction model. DPM-Solver-v3 uses this wrapped continuous-time noise prediction model model_fn.
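
To make the time-label conversion concrete, here is a minimal sketch of how a continuous time could be mapped back to the label a discrete-time network expects. The function name to_model_time is hypothetical, and the mapping is an assumed convention (discrete step $i$ placed at $t=(i+1)/N$, labels scaled by 1000); check model_wrapper in dpm_solver_v3.py for the convention actually used.

```python
def to_model_time(t_continuous, total_N=1000, discrete=True):
    """Map a continuous time in (0, 1] to the time label the network was trained with.

    Assumed convention: t = 1 / total_N maps to label 0, and labels are scaled by 1000.
    Continuous-time models receive t_continuous unchanged.
    """
    if discrete:
        return (t_continuous - 1.0 / total_N) * 1000.0
    return t_continuous


print(to_model_time(1.0))                   # label of the last step of a 1000-step model
print(to_model_time(1e-3))                  # label of the first step
print(to_model_time(0.5, discrete=False))   # continuous-time model: unchanged
```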

Note that DPM-Solver-v3 only needs the noise prediction model (the $\epsilon_\theta(x_t, t)$ model, also known as the "mean" model). For diffusion models that predict both the "mean" and the "variance" (such as improved-DDPM), you first need to define another function yourself that outputs only the "mean".
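
A "mean"-only function of this kind can be a thin wrapper around the original model. The sketch below assumes the model stacks the "mean" channels first and the "variance" channels second along the channel axis, which is how improved-DDPM style models with learned variance are usually laid out. It uses nested lists so that it runs without any dependencies; with tensors of shape (B, 2C, H, W) the same idea would be a slice such as out[:, :C]. The names mean_only and toy_model are hypothetical.

```python
def mean_only(model, num_channels):
    """Wrap a model so that only the first num_channels output channels are returned."""
    def wrapped(x, t, **kwargs):
        out = model(x, t, **kwargs)
        # Keep the "mean" channels of every sample and drop the "variance" channels.
        return [sample[:num_channels] for sample in out]
    return wrapped


def toy_model(x, t):
    # Hypothetical stand-in: 2 "mean" channels followed by 2 "variance" channels.
    return [[xi * 0.1, xi * 0.2, -1.0, -2.0] for xi in x]


noise_model = mean_only(toy_model, num_channels=2)
print(noise_model([1.0, 2.0], t=0.5))
```

The wrapped function can then be passed to model_wrapper in place of the original model.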

Model Types

We support the following four types of diffusion models. You can set the model type by the argument model_type in the function model_wrapper.

| Model Type | Training Objective | Example Paper |
| --- | --- | --- |
| "noise": noise prediction model $\epsilon_\theta$ | $E_{x_0,\epsilon,t}\left[\omega_1(t)\lVert\epsilon_\theta(x_t,t)-\epsilon\rVert_2^2\right]$ | DDPM, Stable-Diffusion |
| "x_start": data prediction model $x_\theta$ | $E_{x_0,\epsilon,t}\left[\omega_2(t)\lVert x_\theta(x_t,t)-x_0\rVert_2^2\right]$ | DALL·E 2 |
| "v": velocity prediction model $v_\theta$ | $E_{x_0,\epsilon,t}\left[\omega_3(t)\lVert v_\theta(x_t,t)-(\alpha_t\epsilon - \sigma_t x_0)\rVert_2^2\right]$ | Imagen Video |
| "score": marginal score function $s_\theta$ | $E_{x_0,\epsilon,t}\left[\omega_4(t)\lVert\sigma_t s_\theta(x_t,t)+\epsilon\rVert_2^2\right]$ | ScoreSDE |

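Since the solver consumes a noise prediction, every other model type has to be converted to $\epsilon$ first. The scalar sketch below derives each conversion from $x_t=\alpha_tx_0+\sigma_t\epsilon$ and the targets in the training objectives above. The "v" conversion additionally assumes a VP schedule ($\alpha_t^2+\sigma_t^2=1$). The function to_noise is a hypothetical illustration, not the body of model_wrapper.

```python
def to_noise(output, x_t, alpha_t, sigma_t, model_type):
    """Convert a model output of the given type to a noise prediction (scalar case)."""
    if model_type == "noise":
        return output
    if model_type == "x_start":
        # x_t = alpha_t * x_0 + sigma_t * eps  =>  eps = (x_t - alpha_t * x_0) / sigma_t
        return (x_t - alpha_t * output) / sigma_t
    if model_type == "v":
        # v = alpha_t * eps - sigma_t * x_0; with alpha_t^2 + sigma_t^2 = 1 this gives
        # eps = alpha_t * v + sigma_t * x_t
        return alpha_t * output + sigma_t * x_t
    if model_type == "score":
        # The score target is -eps / sigma_t  =>  eps = -sigma_t * score
        return -sigma_t * output
    raise ValueError(f"unknown model_type: {model_type}")


# Build consistent targets from a known (x_0, eps) pair under a VP schedule.
alpha_t, sigma_t = 0.6, 0.8
x0, eps = 1.5, -0.3
x_t = alpha_t * x0 + sigma_t * eps
targets = {
    "noise": eps,
    "x_start": x0,
    "v": alpha_t * eps - sigma_t * x0,
    "score": -eps / sigma_t,
}
for name, target in targets.items():
    print(name, to_noise(target, x_t, alpha_t, sigma_t, name))  # each is close to eps
```
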
Sampling Types

We support the following three types of sampling by diffusion models. You can set the argument guidance_type in the function model_wrapper.

| Sampling Type | Equation for Noise Prediction Model | Example Paper |
| --- | --- | --- |
| "uncond": unconditional sampling | $\tilde\epsilon_\theta(x_t,t)=\epsilon_\theta(x_t,t)$ | DDPM |
| "classifier": classifier guidance | $\tilde\epsilon_\theta(x_t,t,c)=\epsilon_\theta(x_t,t)-s\cdot\sigma_t\nabla_{x_t}\log q_\phi(x_t,t,c)$ | ADM, GLIDE |
| "classifier-free": classifier-free guidance (CFG) | $\tilde\epsilon_\theta(x_t,t,c)=s\cdot \epsilon_\theta(x_t,t,c)+(1-s)\cdot\epsilon_\theta(x_t,t)$ | DALL·E 2, Imagen, Stable-Diffusion |

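The two guided equations above are simple elementwise combinations. The following sketch applies them to small lists of numbers standing in for tensors; the function names are hypothetical, and the classifier gradient is passed in as a precomputed value instead of being obtained by automatic differentiation.

```python
def classifier_free_guidance(eps_cond, eps_uncond, scale):
    """eps_tilde = s * eps(x_t, t, c) + (1 - s) * eps(x_t, t), elementwise."""
    return [scale * c + (1.0 - scale) * u for c, u in zip(eps_cond, eps_uncond)]


def classifier_guidance(eps, classifier_grad, sigma_t, scale):
    """eps_tilde = eps(x_t, t) - s * sigma_t * grad_x log q(x_t, t, c), elementwise."""
    return [e - scale * sigma_t * g for e, g in zip(eps, classifier_grad)]


eps_c, eps_u = [0.2, -0.4], [0.1, 0.0]
print(classifier_free_guidance(eps_c, eps_u, 1.0))   # scale 1 returns the conditional prediction
print(classifier_free_guidance(eps_c, eps_u, 7.5))   # larger scales extrapolate away from eps_u
print(classifier_guidance(eps_u, [1.0, -2.0], sigma_t=0.8, scale=2.0))
```
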
2.1. Sampling without Guidance (Unconditional)

The given model has the following format:

model(x_t, t_input, **model_kwargs) -> noise | x_start | v | score

We wrap the model by:

model_fn = model_wrapper(
    model,
    noise_schedule,
    model_type=model_type,  # "noise" or "x_start" or "v" or "score"
    model_kwargs=model_kwargs,  # additional inputs of the model
)

2.2. Sampling with Classifier Guidance (Conditional)

The given model has the following format:

model(x_t, t_input, **model_kwargs) -> noise | x_start | v | score

For DPMs with classifier guidance, we also combine the model output with the classifier gradient. We need to specify the classifier function and the guidance scale. The classifier function has the following format:

classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)

where t_input is the same time label as used by the original diffusion model (model), and cond is the condition (such as class labels).

We wrap the model by:

model_fn = model_wrapper(
    model,
    noise_schedule,
    model_type=model_type,  # "noise" or "x_start" or "v" or "score"
    model_kwargs=model_kwargs,  # additional inputs of the model
    guidance_type="classifier",
    condition=condition,  # conditional input of the classifier
    guidance_scale=guidance_scale,  # classifier guidance scale
    classifier_fn=classifier,
    classifier_kwargs=classifier_kwargs,  # other inputs of the classifier function
)

2.3. Sampling with Classifier-free Guidance (Conditional)

The given model has the following format:

model(x_t, t_input, cond, **model_kwargs) -> noise | x_start | v | score

Note that for classifier-free guidance, the model needs another input cond (such as the text prompt). If cond is a special variable unconditional_condition (such as the empty text ""), then the model output is the unconditional DPM output.

We wrap the model by:

model_fn = model_wrapper(
    model,
    noise_schedule,
    model_type=model_type,  # "noise" or "x_start" or "v" or "score"
    model_kwargs=model_kwargs,  # additional inputs of the model
    guidance_type="classifier-free",
    condition=condition,  # conditional input
    unconditional_condition=unconditional_condition,  # special unconditional condition variable for the unconditional model
    guidance_scale=guidance_scale,  # classifier-free guidance scale
)

3. Define DPM-Solver-v3 and Sample

After defining noise_schedule and model_fn, we can further use them to define DPM-Solver-v3 and generate samples.

First we define the DPM-Solver-v3 instance dpm_solver_v3, and it will automatically handle some necessary preprocessing.

dpm_solver_v3 = DPM_Solver_v3(
    "statistics/sd-v1-4/7.5_250_1024",  # statistics_dir: directory of the precomputed EMS
    noise_schedule,  # positional arguments must come before keyword arguments
    steps=10,
    t_start=1.0,
    t_end=1e-3,
    skip_type="time_uniform",
    degenerated=False,
)
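
For intuition about steps, t_start, t_end and skip_type, the sketch below builds the time grid that a "time_uniform" setting suggests: steps + 1 equally spaced times from t_start down to t_end. This reading of skip_type is an assumption based on its name, and time_uniform_steps is a hypothetical helper, not a function of DPM_Solver_v3.

```python
def time_uniform_steps(t_start, t_end, steps):
    """Return steps + 1 equally spaced times from t_start down to t_end."""
    return [t_start + (t_end - t_start) * i / steps for i in range(steps + 1)]


timesteps = time_uniform_steps(t_start=1.0, t_end=1e-3, steps=10)
print(len(timesteps))                 # 11 grid points for 10 solver steps
print(timesteps[0], timesteps[-1])    # from t_start to (approximately) t_end
```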

Then we can use dpm_solver_v3.sample to quickly sample from DPMs. This function computes the ODE solution at time t_end by DPM-Solver-v3, given the initial x_T at time t_start.

x_sample = dpm_solver_v3.sample(
    x_T,
    model_fn,
    order=3,
    p_pseudo=False,
    use_corrector=True,
    c_pseudo=True,
    lower_order_final=True,
    half=False,
)
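
To illustrate what "computing the ODE solution at t_end given x_T at t_start" means, here is a toy sampler in plain Python. It is deliberately not DPM-Solver-v3: it uses the much simpler first-order DDIM update, a scalar state, and a hypothetical toy_model_fn that is the exact noise predictor for a data distribution concentrated at a single value MU. All names in the block are illustrative.

```python
import math


def alpha_sigma(t, beta_0=0.1, beta_1=20.0):
    """Continuous-time linear VP schedule."""
    log_alpha = -0.25 * (beta_1 - beta_0) * t ** 2 - 0.5 * beta_0 * t
    alpha = math.exp(log_alpha)
    return alpha, math.sqrt(1.0 - alpha ** 2)


MU = 2.0  # toy data distribution: every data point equals MU


def toy_model_fn(x, t):
    """Exact noise prediction when x_0 is always MU: eps = (x_t - alpha_t * MU) / sigma_t."""
    alpha, sigma = alpha_sigma(t)
    return (x - alpha * MU) / sigma


def ddim_sample(x_T, model_fn, steps=10, t_start=1.0, t_end=1e-3):
    """Integrate the diffusion ODE from t_start to t_end with first-order DDIM steps."""
    times = [t_start + (t_end - t_start) * i / steps for i in range(steps + 1)]
    x = x_T
    for t, s in zip(times[:-1], times[1:]):
        alpha_t, sigma_t = alpha_sigma(t)
        alpha_s, sigma_s = alpha_sigma(s)
        eps = model_fn(x, t)
        x0_pred = (x - sigma_t * eps) / alpha_t   # predicted clean sample
        x = alpha_s * x0_pred + sigma_s * eps     # move to the next time s
    return x


result = ddim_sample(x_T=0.5, model_fn=toy_model_fn)
print(result)   # close to MU
```

The interface mirrors the call above (an initial x_T and a model_fn taking x and a continuous time), which is why higher-order solvers can be swapped in without changing the model wrapper.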

EMS Computing by Yourself

We provide the code example compute_EMS_scoresde.py for EMS computing in the score_sde codebase. It distributes the computation across multiple GPUs using torch.multiprocessing and leverages torch.autograd.forward_ad for fast calculation of Jacobian-vector products (JVPs) with forward-mode automatic differentiation. The procedure is detailed in Appendix C.1.1 of the paper. The arguments correspond to the paper in the following way:

To compute the EMS for your own model, please change the model loading and dataset in the code.
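
For readers unfamiliar with forward-mode automatic differentiation, the idea behind a JVP can be shown with dual numbers in plain Python. This is only a conceptual sketch: it does not use torch.autograd.forward_ad, it is not taken from compute_EMS_scoresde.py, and the class Dual, the helpers and the toy function f are hypothetical.

```python
import math


class Dual:
    """A value paired with its directional derivative (tangent)."""

    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule applied to the tangent part.
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def dual_sin(d):
    """sin with the chain rule applied to the tangent."""
    return Dual(math.sin(d.primal), math.cos(d.primal) * d.tangent)


def jvp(f, x, v):
    """Return (f(x), J_f(x) v) from one forward pass, without forming the Jacobian."""
    outputs = f([Dual(xi, vi) for xi, vi in zip(x, v)])
    return [o.primal for o in outputs], [o.tangent for o in outputs]


def f(x):
    # Toy function R^2 -> R^2 standing in for a network.
    return [x[0] * x[1], dual_sin(x[0]) + 3.0 * x[1]]


value, tangent = jvp(f, x=[1.0, 2.0], v=[0.5, -1.0])
print(value, tangent)
```

One forward pass yields both the function value and the Jacobian applied to the chosen direction, which is what makes forward mode attractive when only JVPs are needed.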

Acknowledgement

Special thanks to DPM-Solver and DPM-Solver++ for their unified model wrapper to support various diffusion models.

The predictor-corrector method is inspired by UniPC.

We use the pretrained diffusion models and codebases provided by:

ScoreSDE, EDM, Guided-Diffusion, Latent-Diffusion, Stable-Diffusion

Citation

If you find our work useful in your research, please consider citing:

@inproceedings{zheng2023dpm,
    title={DPM-Solver-v3: Improved Diffusion ODE Solver with Empirical Model Statistics},
    author={Zheng, Kaiwen and Lu, Cheng and Chen, Jianfei and Zhu, Jun},
    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
    year={2023}
}