yang-song / score_sde_pytorch

PyTorch implementation for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)
https://arxiv.org/abs/2011.13456
Apache License 2.0
controllable-generation diffusion-models generative-models iclr-2021 inverse-problems pytorch score-based-generative-modeling score-matching stochastic-differential-equations

Score-Based Generative Modeling through Stochastic Differential Equations

This repo contains a PyTorch implementation for the paper Score-Based Generative Modeling through Stochastic Differential Equations

by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole


We propose a unified framework that generalizes and improves previous work on score-based generative models through the lens of stochastic differential equations (SDEs). In particular, we can transform data to a simple noise distribution with a continuous-time stochastic process described by an SDE. This SDE can be reversed for sample generation if we know the score of the marginal distributions at each intermediate time step, which can be estimated with score matching. The basic idea is captured in the figure below:

[Figure: schematic]
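For reference, the forward SDE, its reverse-time counterpart, and the denoising score matching objective used to estimate the score can be written as follows. Here w is a standard Wiener process, w̄ is a reverse-time Wiener process (time runs from T back to 0), s_θ is the time-dependent score model, and λ(t) > 0 is a weighting function.

```latex
% Forward SDE: diffuses data x(0) ~ p_0 toward a simple prior p_T
\[ \mathrm{d}\mathbf{x} = \mathbf{f}(\mathbf{x}, t)\,\mathrm{d}t + g(t)\,\mathrm{d}\mathbf{w} \]

% Reverse-time SDE: requires only the score \nabla_{\mathbf{x}} \log p_t(\mathbf{x})
\[ \mathrm{d}\mathbf{x} = \left[\mathbf{f}(\mathbf{x}, t) - g(t)^2\,\nabla_{\mathbf{x}} \log p_t(\mathbf{x})\right]\mathrm{d}t + g(t)\,\mathrm{d}\bar{\mathbf{w}} \]

% Denoising score matching objective for the score model s_\theta
\[ \theta^* = \arg\min_\theta\; \mathbb{E}_t\Big\{\lambda(t)\,\mathbb{E}_{\mathbf{x}(0)}\,\mathbb{E}_{\mathbf{x}(t)\mid\mathbf{x}(0)}\Big[\big\lVert \mathbf{s}_\theta(\mathbf{x}(t), t) - \nabla_{\mathbf{x}(t)} \log p_{0t}(\mathbf{x}(t)\mid\mathbf{x}(0)) \big\rVert_2^2\Big]\Big\} \]
```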

Our work enables a better understanding of existing approaches, new sampling algorithms, exact likelihood computation, uniquely identifiable encoding, and latent code manipulation, and it brings new conditional generation abilities (including but not limited to class-conditional generation, inpainting, and colorization) to the family of score-based generative models.

All combined, we achieved an FID of 2.20 and an Inception score of 9.89 for unconditional generation on CIFAR-10, as well as high-fidelity generation of 1024px CelebA-HQ images (samples below). In addition, we obtained a likelihood value of 2.99 bits/dim on uniformly dequantized CIFAR-10 images.

[Figure: FFHQ samples]
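Likelihoods in this README are reported in bits per dimension. As a sketch of the unit conversion only, assuming a per-image negative log-likelihood in nats and CIFAR-10's 3 × 32 × 32 = 3072 dimensions, bits/dim is the NLL divided by D · ln 2. The helper names are hypothetical, and the sketch ignores the constant offset introduced when pixel values are rescaled or dequantized, which an actual evaluation has to account for.

```python
import math

CIFAR10_DIMS = 3 * 32 * 32  # channels x height x width = 3072


def nats_to_bits_per_dim(nll_nats, num_dims=CIFAR10_DIMS):
    """Convert a per-image negative log-likelihood from nats to bits per dimension."""
    return nll_nats / (num_dims * math.log(2.0))


def bits_per_dim_to_nats(bpd, num_dims=CIFAR10_DIMS):
    """Inverse conversion: total nats per image for a given bits/dim value."""
    return bpd * num_dims * math.log(2.0)


# The 2.99 bits/dim figure quoted above, expressed as nats per CIFAR-10 image.
nats = bits_per_dim_to_nats(2.99)
print(f"2.99 bits/dim corresponds to about {nats:.1f} nats per image")
```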

What does this code do?

Aside from the NCSN++ and DDPM++ models in our paper, this codebase also re-implements many previous score-based models in one place, including NCSN from Generative Modeling by Estimating Gradients of the Data Distribution, NCSNv2 from Improved Techniques for Training Score-Based Generative Models, and DDPM from Denoising Diffusion Probabilistic Models.
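In the paper, NCSN's noise scales are shown to discretize the VE SDE and DDPM's noise schedule to discretize the VP SDE. As a quick numerical check of the DDPM case, a sketch: with DDPM's linear schedule (β from 1e-4 to 0.02 over 1000 steps) and the VP SDE's β(t) = 0.1 + 19.9t, the total signal attenuation ᾱ of the discrete chain should come out close to the continuous exp(-∫β). The function names are illustrative, not the repo's.

```python
import math


def ddpm_alpha_bar(n_steps=1000, beta_1=1e-4, beta_n=0.02):
    """Discrete DDPM: alpha_bar_N = prod_i (1 - beta_i) for linearly spaced beta_i."""
    log_alpha_bar = 0.0
    for i in range(n_steps):
        beta_i = beta_1 + (beta_n - beta_1) * i / (n_steps - 1)
        log_alpha_bar += math.log1p(-beta_i)  # accumulate in log space for stability
    return math.exp(log_alpha_bar)


def vp_alpha_bar(t=1.0, beta_min=0.1, beta_max=20.0):
    """VP SDE dx = -0.5 beta(t) x dt + sqrt(beta(t)) dw with linear beta(t):
    the squared mean coefficient of p_0t is exp(-int_0^t beta(s) ds)."""
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2
    return math.exp(-integral)


# beta_min = N * beta_1 and beta_max = N * beta_N with N = 1000.
discrete, continuous = ddpm_alpha_bar(), vp_alpha_bar()
print(f"DDPM alpha_bar_N = {discrete:.3e}, VP SDE exp(-int beta) = {continuous:.3e}")
```

The two values agree closely but not exactly: the sum of the discrete β_i equals the integral of β(t), and the small gap comes from the higher-order terms of log(1 - β_i) that the continuous limit drops.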

It supports training new models and evaluating the sample quality and likelihoods of existing models. We carefully designed the code to be modular and easily extensible to new SDEs, predictors, or correctors.
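To illustrate what plugging in a new predictor or corrector involves, here is a minimal, self-contained sketch of a predictor-corrector (PC) sampler. It is not the repo's API: the class names and method signatures are hypothetical, and a 1-D Gaussian toy problem with a known score stands in for a trained network. The predictor is a reverse-time Euler-Maruyama step; the corrector is a Langevin step whose size is set from a target signal-to-noise ratio (snr=0.16 is an illustrative choice).

```python
import math
import random
from abc import ABC, abstractmethod

# Toy setting (assumption): 1-D data x(0) ~ N(0, DATA_STD^2) diffused by a VE SDE.
# Its marginal score is known in closed form and replaces a trained score model.
SIGMA_MIN, SIGMA_MAX, DATA_STD = 0.01, 10.0, 1.0


def sigma(t):
    """VE noise scale, geometric between SIGMA_MIN (t=0) and SIGMA_MAX (t=1)."""
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def g2(t):
    """Squared diffusion g(t)^2 = d[sigma(t)^2]/dt; the VE drift f is zero."""
    return 2.0 * sigma(t) ** 2 * math.log(SIGMA_MAX / SIGMA_MIN)


def marginal_var(t):
    """Variance of p_t: DATA_STD^2 + sigma(t)^2 - sigma(0)^2."""
    return DATA_STD ** 2 + sigma(t) ** 2 - SIGMA_MIN ** 2


def score(x, t):
    """Exact score of the Gaussian marginal p_t."""
    return -x / marginal_var(t)


class Predictor(ABC):
    @abstractmethod
    def update(self, xs, t, dt):
        """Move a batch from time t to t - dt along the reverse-time SDE."""


class Corrector(ABC):
    @abstractmethod
    def update(self, xs, t):
        """Refine a batch at fixed time t with score-based MCMC."""


class EulerMaruyamaPredictor(Predictor):
    def update(self, xs, t, dt):
        a = g2(t) * dt
        # Reverse step of dx = -g^2 * score dt + g dw_bar, taken backwards in time.
        return [x + a * score(x, t) + math.sqrt(a) * random.gauss(0.0, 1.0) for x in xs]


class LangevinCorrector(Corrector):
    def __init__(self, snr=0.16):
        self.snr = snr  # target signal-to-noise ratio that sets the step size

    def update(self, xs, t):
        grads = [score(x, t) for x in xs]
        noise = [random.gauss(0.0, 1.0) for _ in xs]
        grad_norm = math.sqrt(sum(s * s for s in grads))
        noise_norm = math.sqrt(sum(z * z for z in noise))
        eps = 2.0 * (self.snr * noise_norm / grad_norm) ** 2
        # One Langevin step: x + eps * score + sqrt(2 * eps) * z
        return [x + eps * s + math.sqrt(2.0 * eps) * z for x, s, z in zip(xs, grads, noise)]


def pc_sample(predictor, corrector, n_samples=1000, n_steps=500, seed=0):
    random.seed(seed)
    # Start from the t = 1 marginal (the prior), then alternate corrector and predictor.
    xs = [random.gauss(0.0, math.sqrt(marginal_var(1.0))) for _ in range(n_samples)]
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = 1.0 - i * dt
        xs = corrector.update(xs, t)
        xs = predictor.update(xs, t, dt)
    return xs
```

Calling `pc_sample(EulerMaruyamaPredictor(), LangevinCorrector())` should return samples with mean near 0 and standard deviation near DATA_STD. Under this design, a new predictor or corrector only needs to implement `update`.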

Integration with 🤗 Diffusers library

Most models are now also available in 🧨 Diffusers and accessible via the ScoreSdeVE pipeline.

Diffusers allows you to test score-SDE-based models in PyTorch in just a couple of lines of code.

You can install diffusers as follows:

pip install diffusers torch accelerate

And then try out the models with just a couple of lines of code:

from diffusers import DiffusionPipeline

model_id = "google/ncsnpp-ffhq-1024"

# load model and scheduler
sde_ve = DiffusionPipeline.from_pretrained(model_id)

# run pipeline in inference (sample random noise and denoise)
image = sde_ve().images[0]

# save image (`.images[0]` above is already a single PIL image)
image.save("sde_ve_generated_image.png")

More models can be found directly on the Hub.

JAX version

Please find a JAX implementation here, which additionally supports class-conditional generation with a pre-trained classifier and resuming an evaluation process after pre-emption.

JAX vs. PyTorch

In general, this PyTorch version consumes less memory but runs slower than JAX. Here is a benchmark on training an NCSN++ cont. model with VE SDE, on 4x Nvidia Tesla V100 GPUs (32GB).

Framework                  Time (seconds per step)   Memory usage in total (GB)
PyTorch                    0.56                      20.6
JAX (n_jitted_steps=1)     0.30                      29.7
JAX (n_jitted_steps=5)     0.20                      74.8
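The benchmark figures can be turned into throughput and relative cost with simple arithmetic. The sketch below only reuses the numbers from the table; splitting memory evenly across the four GPUs is an assumption, since the table reports only the total.

```python
# Figures copied from the benchmark table (4x V100, NCSN++ cont. with VE SDE).
BENCHMARK = {
    "PyTorch": (0.56, 20.6),  # (seconds per step, total memory in GB)
    "JAX (n_jitted_steps=1)": (0.30, 29.7),
    "JAX (n_jitted_steps=5)": (0.20, 74.8),
}
NUM_GPUS = 4


def summarize(benchmark, num_gpus=NUM_GPUS):
    """Derive throughput and cost ratios relative to the PyTorch row."""
    base_time, base_mem = benchmark["PyTorch"]
    rows = {}
    for name, (sec_per_step, mem_total) in benchmark.items():
        rows[name] = {
            "steps_per_hour": 3600.0 / sec_per_step,
            "speedup_vs_pytorch": base_time / sec_per_step,
            "memory_vs_pytorch": mem_total / base_mem,
            # Assumption: memory is split evenly across GPUs.
            "mem_per_gpu_gb": mem_total / num_gpus,
        }
    return rows


for name, r in summarize(BENCHMARK).items():
    print(f"{name:24s} {r['steps_per_hour']:7.0f} steps/h  "
          f"{r['speedup_vs_pytorch']:.2f}x speed  "
          f"{r['memory_vs_pytorch']:.2f}x memory  ~{r['mem_per_gpu_gb']:.1f} GB/GPU")
```

By these numbers, JAX with n_jitted_steps=5 runs 2.8x faster per step than PyTorch while using roughly 3.6x the total memory.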

How to run the code

Dependencies

Run the following to install a subset of the Python packages required by our code:

pip install -r requirements.txt

Stats files for quantitative evaluation

We provide the stats file for CIFAR-10. You can download cifar10_stats.npz and save it to assets/stats/. Check out issue #5 on how to compute this stats file for new datasets.
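The stats file provides reference statistics of real CIFAR-10 images for FID; its exact contents are not described here. As background, FID fits Gaussians (μ_r, Σ_r) and (μ_g, Σ_g) to Inception-network activations of real and generated images and measures the Fréchet distance between them:

```latex
\[ \mathrm{FID} = \lVert \boldsymbol{\mu}_r - \boldsymbol{\mu}_g \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right) \]
```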

Usage

Train and evaluate our models through main.py.

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval>: Running mode: train or eval
  --workdir: Working directory
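The flag listing above reads like absl-style help output. As a hedged, dependency-free stand-in, the sketch below mirrors those four flags with Python's standard argparse so their semantics are explicit. Treating --mode and --workdir as required, and the config path used in the example, are assumptions; the real main.py may parse its flags differently.

```python
import argparse


def build_parser():
    """Hypothetical argparse mirror of the main.py flags listed above."""
    parser = argparse.ArgumentParser(prog="main.py")
    parser.add_argument("--config", default=None, help="Training configuration.")
    parser.add_argument("--eval_folder", default="eval",
                        help="The folder name for storing evaluation results")
    # Assumption: mode and workdir show no defaults in the listing, so treat them as required.
    parser.add_argument("--mode", choices=["train", "eval"], required=True,
                        help="Running mode: train or eval")
    parser.add_argument("--workdir", required=True, help="Working directory")
    return parser


if __name__ == "__main__":
    # Example invocation; the config path and workdir are hypothetical.
    args = build_parser().parse_args([
        "--config", "configs/ve/cifar10_ncsnpp_continuous.py",
        "--workdir", "runs/cifar10_ve",
        "--mode", "train",
    ])
    print(args)
```

With this layout, evaluation results would land in a subfolder named by --eval_folder (default `eval`) under the working directory, and any --mode other than train or eval is rejected at parse time.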

How to extend the code

Pretrained checkpoints

All checkpoints are provided in this Google Drive folder.

Instructions: You may find two checkpoints for some models. The first checkpoint (with a smaller number) is the one for which we reported FID scores in our paper's Table 3 (also corresponding to the FID and IS columns in the table below). The second checkpoint (with a larger number) is the one for which we reported likelihood values and FIDs of black-box ODE samplers in our paper's Table 2 (also the FID (ODE) and NLL (bits/dim) columns in the table below). The former corresponds to the smallest FID during the course of training (evaluated every 50k iterations). The latter is the last checkpoint during training.
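The rule for choosing between the two checkpoints can be sketched in a few lines. This is a hypothetical helper, not code from the repo: it assumes FID has already been computed at each evaluated iteration and collected in a dict, and the numbers in the example are toy values, not results.

```python
def select_checkpoints(fid_by_iteration):
    """Return (best_fid_iteration, last_iteration) from an {iteration: FID} mapping.

    The first is the checkpoint with the smallest FID among periodic evaluations;
    the second is simply the last evaluated checkpoint.
    """
    if not fid_by_iteration:
        raise ValueError("no evaluated checkpoints")
    best = min(fid_by_iteration, key=fid_by_iteration.get)  # lowest FID wins
    last = max(fid_by_iteration)  # latest training iteration
    return best, last


# Toy numbers for illustration only (not results from the paper).
toy_fids = {50_000: 9.0, 100_000: 4.0, 150_000: 5.0}
print(select_checkpoints(toy_fids))  # (100000, 150000)
```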

Per Google's policy, we cannot release our original CelebA and CelebA-HQ checkpoints. That said, I have re-trained models on FFHQ 1024px, FFHQ 256px and CelebA-HQ 256px with personal resources, and they achieved similar performance to our internal checkpoints.

Here is a detailed list of checkpoints and their results reported in the paper. FID (ODE) corresponds to the sample quality of a black-box ODE solver applied to the probability flow ODE.

Checkpoint path                       FID    IS     FID (ODE)   NLL (bits/dim)
ve/cifar10_ncsnpp/                    2.45   9.73   -           -
ve/cifar10_ncsnpp_continuous/         2.38   9.83   -           -
ve/cifar10_ncsnpp_deep_continuous/    2.20   9.89   -           -
vp/cifar10_ddpm/                      3.24   -      3.37        3.28
vp/cifar10_ddpm_continuous            -      -      3.69        3.21
vp/cifar10_ddpmpp                     2.78   9.64   -           -
vp/cifar10_ddpmpp_continuous          2.55   9.58   3.93        3.16
vp/cifar10_ddpmpp_deep_continuous     2.41   9.68   3.08        3.13
subvp/cifar10_ddpm_continuous         -      -      3.56        3.05
subvp/cifar10_ddpmpp_continuous       2.61   9.56   3.16        3.02
subvp/cifar10_ddpmpp_deep_continuous  2.41   9.57   2.92        2.99

Checkpoint path                       Samples
ve/bedroom_ncsnpp_continuous          [image: bedroom_samples]
ve/church_ncsnpp_continuous           [image: church_samples]
ve/ffhq_1024_ncsnpp_continuous        [image: ffhq_1024]
ve/ffhq_256_ncsnpp_continuous         [image: ffhq_256_samples]
ve/celebahq_256_ncsnpp_continuous     [image: celebahq_256_samples]
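The FID (ODE) and NLL columns come from the probability flow ODE, dx = [f(x, t) - ½ g(t)² ∇_x log p_t(x)] dt, which shares the SDE's marginal densities but is deterministic. The sketch below integrates it for a 1-D Gaussian VE toy problem (an assumption standing in for a trained score model). Instead of a black-box adaptive solver, it swaps in fixed-step RK4 so the example has no dependencies; for this toy the ODE also has a closed-form solution to check against.

```python
import math

# Toy setting (assumption): 1-D data x(0) ~ N(0, DATA_STD^2) under a VE SDE.
SIGMA_MIN, SIGMA_MAX, DATA_STD = 0.01, 10.0, 1.0


def sigma(t):
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def g2(t):
    # g(t)^2 = d[sigma(t)^2]/dt for the VE SDE (drift f = 0)
    return 2.0 * sigma(t) ** 2 * math.log(SIGMA_MAX / SIGMA_MIN)


def marginal_var(t):
    return DATA_STD ** 2 + sigma(t) ** 2 - SIGMA_MIN ** 2


def ode_drift(x, t):
    # Probability flow ODE with f = 0: dx/dt = -0.5 * g(t)^2 * score(x, t),
    # using the exact Gaussian score -x / marginal_var(t) in place of a network.
    return -0.5 * g2(t) * (-x / marginal_var(t))


def rk4_solve(x, t_start, t_end, n_steps=200):
    """Fixed-step RK4; h is negative when integrating from t = 1 down to t = 0."""
    h = (t_end - t_start) / n_steps
    t = t_start
    for _ in range(n_steps):
        k1 = ode_drift(x, t)
        k2 = ode_drift(x + 0.5 * h * k1, t + 0.5 * h)
        k3 = ode_drift(x + 0.5 * h * k2, t + 0.5 * h)
        k4 = ode_drift(x + h * k3, t + h)
        x += h * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0
        t += h
    return x


x1 = 5.0  # a latent point at t = 1
x0 = rk4_solve(x1, 1.0, 0.0)
# For this Gaussian toy, x(t) scales with sqrt(marginal_var(t)) along the flow.
exact = x1 * math.sqrt(marginal_var(0.0) / marginal_var(1.0))
print(f"RK4: {x0:.6f}  closed form: {exact:.6f}")
```

Because the map from x(1) to x(0) is deterministic and invertible, the same ODE is what makes encoding and exact likelihood computation possible.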

Demonstrations and tutorials

Link             Description
Open In Colab    Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis (JAX + FLAX)
Open In Colab    Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis (PyTorch)
Open In Colab    Tutorial of score-based generative models in JAX + FLAX
Open In Colab    Tutorial of score-based generative models in PyTorch

Tips

References

If you find the code useful for your research, please consider citing

@inproceedings{
  song2021scorebased,
  title={Score-Based Generative Modeling through Stochastic Differential Equations},
  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=PxTIG12RRHS}
}

This work is built upon some previous papers which might also interest you: