crowsonkb / v-diffusion-pytorch

v objective diffusion inference code for PyTorch.
MIT License

v-diffusion-pytorch

v objective diffusion inference code for PyTorch, by Katherine Crowson (@RiversHaveWings) and Chainbreakers AI (@jd_pressman).

The models are denoising diffusion probabilistic models (https://arxiv.org/abs/2006.11239), which are trained to reverse a gradual noising process, allowing the models to generate samples from the learned data distributions starting from random noise. The models are also trained on continuous timesteps. They use the 'v' objective from Progressive Distillation for Fast Sampling of Diffusion Models (https://openreview.net/forum?id=TIdIXIpzhoI). Guided diffusion sampling scripts (https://arxiv.org/abs/2105.05233) are included, specifically CLIP guided diffusion. This repo also includes a diffusion model conditioned on CLIP text embeddings that supports classifier-free guidance (https://openreview.net/pdf?id=qw8AKxfYbI), similar to GLIDE (https://arxiv.org/abs/2112.10741). Sampling methods include DDPM, DDIM (https://arxiv.org/abs/2010.02502), and PRK/PLMS (https://openreview.net/forum?id=PlKWVd2yBkY).
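To make the 'v' objective concrete, here is a minimal plain-Python sketch. It assumes a cosine schedule on continuous timesteps, alpha(t) = cos(πt/2) and sigma(t) = sin(πt/2), so that alpha² + sigma² = 1. The function names are hypothetical and are not this package's API. Given a clean sample x0 and Gaussian noise eps, the noised input is x_t = alpha·x0 + sigma·eps and the training target is v = alpha·eps − sigma·x0. The denoised sample and the noise can both be recovered from x_t and v.

```python
import math
import random


def t_to_alpha_sigma(t):
    """Map a continuous timestep t in [0, 1] to (alpha, sigma) on an assumed cosine schedule."""
    return math.cos(t * math.pi / 2), math.sin(t * math.pi / 2)


def add_noise(x0, eps, t):
    """Forward process: x_t = alpha * x0 + sigma * eps."""
    alpha, sigma = t_to_alpha_sigma(t)
    return [alpha * x + sigma * e for x, e in zip(x0, eps)]


def v_target(x0, eps, t):
    """The 'v' training target: v = alpha * eps - sigma * x0."""
    alpha, sigma = t_to_alpha_sigma(t)
    return [alpha * e - sigma * x for x, e in zip(x0, eps)]


def split_v(x_t, v, t):
    """Recover (denoised prediction, noise prediction) from x_t and a v prediction."""
    alpha, sigma = t_to_alpha_sigma(t)
    pred = [alpha * x - sigma * vi for x, vi in zip(x_t, v)]
    eps = [sigma * x + alpha * vi for x, vi in zip(x_t, v)]
    return pred, eps


# Round trip with a perfect "model" that returns the exact v target.
rng = random.Random(0)
x0 = [rng.gauss(0, 1) for _ in range(4)]
eps = [rng.gauss(0, 1) for _ in range(4)]
t = 0.3
x_t = add_noise(x0, eps, t)
pred, eps_hat = split_v(x_t, v_target(x0, eps, t), t)
```

Because alpha² + sigma² = 1, the two recoveries are exact linear inversions. A model that predicts v therefore yields both the denoised image and the noise estimate that the samplers need.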

Thank you to stability.ai for compute to train these models!

Installation

pip install v-diffusion-pytorch

Or clone the repository with git clone and then run pip install -e . from its root.

Model checkpoints:

cc12m_1_cfg: a 602M-parameter, CLIP-conditioned model trained on Conceptual 12M for 3.1M steps and then fine-tuned for classifier-free guidance for 250K additional steps. This is the recommended model to use.

cc12m_1: as above, before CFG fine-tuning. This is the model from the original release of this repo.

A 481M-parameter unconditional model trained on a 33-million-image, original-resolution subset of Yahoo Flickr Creative Commons 100 Million.

A 968M-parameter unconditional model trained on a 33-million-image, original-resolution subset of Yahoo Flickr Creative Commons 100 Million.

The repo also contains PyTorch ports of the four models from v-diffusion-jax: danbooru_128, imagenet_128, wikiart_128, and wikiart_256.

Sampling

Example

If the model checkpoint for cc12m_1_cfg is stored in checkpoints/, the following will generate four images:

./cfg_sample.py "the rise of consciousness":5 -n 4 -bs 4 --seed 0

If the checkpoint is stored elsewhere, specify its path with --checkpoint.

Colab

There is a cc12m_1_cfg Colab notebook (a simplified version of cfg_sample.py) that can be used for free.

CFG sampling (best, but only cc12m_1_cfg supports it)

usage: cfg_sample.py [-h] [--images [IMAGE ...]] [--batch-size BATCH_SIZE]
                     [--checkpoint CHECKPOINT] [--device DEVICE] [--eta ETA] [--init INIT]
                     [--method {ddpm,ddim,prk,plms,pie,plms2,iplms}] [--model {cc12m_1_cfg}]
                     [-n N] [--seed SEED] [--size SIZE SIZE]
                     [--starting-timestep STARTING_TIMESTEP] [--steps STEPS]
                     [prompts ...]

prompts: the text prompts to use. Weights for text prompts can be specified by putting the weight after a colon, for example: "the rise of consciousness:5". A weight of 1 samples images that match the prompt about as closely as training images typically match similar captions. The default weight is 3.
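One way to read these weights is as classifier-free guidance weights. The sketch below assumes that the model is queried once without conditioning and once per prompt, and that the unconditional prediction receives weight 1 − sum(weights). With a single prompt of weight w, this reduces to v_uncond + w·(v_cond − v_uncond). So w = 1 returns the plain conditional prediction, and the default w = 3 extrapolates past it. combine_cfg is a hypothetical helper that operates on flat lists. It is not a function from this repo.

```python
def combine_cfg(v_uncond, v_conds, weights):
    """Classifier-free guidance: weighted sum of unconditional and conditional predictions.

    v_uncond: unconditional model output (flat list of floats).
    v_conds:  one conditional output per prompt, each the same length as v_uncond.
    weights:  one guidance weight per prompt.
    """
    if len(v_conds) != len(weights):
        raise ValueError("need exactly one weight per conditional prediction")
    all_vs = [v_uncond, *v_conds]
    # The unconditional branch gets 1 - sum(weights), so all weights sum to 1.
    all_ws = [1.0 - sum(weights), *weights]
    return [sum(w * v[i] for w, v in zip(all_ws, all_vs)) for i in range(len(v_uncond))]


v_uncond = [0.0, 1.0]
v_cond = [1.0, 1.0]
plain = combine_cfg(v_uncond, [v_cond], [1.0])   # equals v_cond
guided = combine_cfg(v_uncond, [v_cond], [3.0])  # pushed 3x along (v_cond - v_uncond)
```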

--batch-size: sample this many images at a time (default 1)

--checkpoint: manually specify the model checkpoint file

--device: the PyTorch device name to use (default autodetects)

--eta: set to 0 (the default) while using --method ddim for deterministic (DDIM) sampling, 1 for stochastic (DDPM) sampling, and in between to interpolate between the two.
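The role of eta can be sketched as a single DDIM-style update from timestep t to an earlier timestep t_next. This sketch follows the usual DDIM eta parameterization. It is written for the same assumed cosine schedule (alpha = cos(πt/2), sigma = sin(πt/2)) and for a model that predicts v. ddim_step is a hypothetical name, and the arithmetic is shown on plain lists rather than tensors.

```python
import math
import random


def t_to_alpha_sigma(t):
    """Cosine schedule (assumed): alpha = cos(pi t / 2), sigma = sin(pi t / 2)."""
    return math.cos(t * math.pi / 2), math.sin(t * math.pi / 2)


def ddim_step(x, v, t, t_next, eta, rng):
    """One step from timestep t to t_next < t, given the model's v prediction at t."""
    alpha, sigma = t_to_alpha_sigma(t)
    alpha_next, sigma_next = t_to_alpha_sigma(t_next)
    pred = [alpha * xi - sigma * vi for xi, vi in zip(x, v)]
    eps = [sigma * xi + alpha * vi for xi, vi in zip(x, v)]
    # Std of the fresh noise: 0 for eta = 0 (DDIM), DDPM-like for eta = 1.
    ddim_sigma = (eta * math.sqrt(sigma_next**2 / sigma**2)
                  * math.sqrt(1 - alpha**2 / alpha_next**2))
    # Shrink the deterministic noise term so the total noise variance stays sigma_next**2.
    adjusted_sigma = math.sqrt(max(sigma_next**2 - ddim_sigma**2, 0.0))
    x_next = [alpha_next * p + adjusted_sigma * e for p, e in zip(pred, eps)]
    if eta:
        x_next = [xi + ddim_sigma * rng.gauss(0, 1) for xi in x_next]
    return x_next


# With an exact v and eta = 0, the step lands on alpha_next * x0 + sigma_next * eps.
x0, eps = [0.5, -1.0], [1.0, 2.0]
t, t_next = 0.6, 0.4
a, s = t_to_alpha_sigma(t)
x_t = [a * x + s * e for x, e in zip(x0, eps)]
v = [a * e - s * x for x, e in zip(x0, eps)]
x_det = ddim_step(x_t, v, t, t_next, eta=0.0, rng=random.Random(0))
```

Intermediate eta values trade some of the deterministic noise term for freshly sampled noise. This is the interpolation between DDIM and DDPM that the option describes.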

--images: the image prompts to use (local files or HTTP(S) URLs). Weights for image prompts can be specified by putting the weight after a colon, for example: "image_1.png:5". The default weight is 3.
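Because image prompts may be URLs, which already contain a colon, a weight suffix has to be parsed from the right. The following sketch treats only a trailing ":<number>" as a weight and falls back to a default of 3. parse_prompt is a hypothetical helper, not necessarily how the scripts implement it.

```python
def parse_prompt(prompt, default_weight=3.0):
    """Split "text:weight" or "path_or_url:weight" into (text, weight).

    Only a trailing ":<number>" counts as a weight, so the colon in
    "https://..." is left alone. A prompt without one gets default_weight.
    """
    head, sep, tail = prompt.rpartition(":")
    if sep:
        try:
            return head, float(tail)
        except ValueError:
            pass  # e.g. the "//host/..." part of a URL is not a number
    return prompt, default_weight


print(parse_prompt("image_1.png:5"))                # ('image_1.png', 5.0)
print(parse_prompt("https://example.com/cat.jpg"))  # ('https://example.com/cat.jpg', 3.0)
```

One consequence of this rule is that a text prompt that itself ends in something like ":9" would be read as carrying a weight.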

--init: specify the init image (optional)

--method: specify the sampling method to use (DDPM, DDIM, PRK, PLMS, PIE, PLMS2, or IPLMS) (default PLMS). DDPM is the original SDE sampling method, DDIM integrates the probability flow ODE using a first-order method, PLMS is fourth-order pseudo Adams-Bashforth, and PLMS2 is second-order pseudo Adams-Bashforth. PRK (fourth-order Pseudo Runge-Kutta) and PIE (second-order Pseudo Improved Euler) are used to bootstrap PLMS and PLMS2 but can also be used on their own (slow). IPLMS is the fourth-order "Improved PLMS" sampler from Fast Sampling of Diffusion Models with Exponential Integrator (https://arxiv.org/abs/2204.13902).
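The multistep part of PLMS and PLMS2 can be sketched as a weighted combination of the model's most recent noise (eps) predictions, using the classical Adams-Bashforth coefficients. The combined eps would then feed a DDIM-style transition step. plms_eps and AB_COEFFS are hypothetical names. The sketch omits the PRK/PIE warm-up that supplies the first few predictions, and it does not cover IPLMS.

```python
# Adams-Bashforth weights, newest prediction first.
AB_COEFFS = {
    2: (3 / 2, -1 / 2),                        # PLMS2: second order
    4: (55 / 24, -59 / 24, 37 / 24, -9 / 24),  # PLMS: fourth order
}


def plms_eps(eps_history, order=4):
    """Combine the most recent eps predictions (newest first) into one extrapolated eps."""
    coeffs = AB_COEFFS[order]
    if len(eps_history) < order:
        raise ValueError(f"need {order} past predictions; bootstrap with PRK or PIE first")
    return [sum(c * e[i] for c, e in zip(coeffs, eps_history))
            for i in range(len(eps_history[0]))]


# If eps changes linearly from step to step, both orders extrapolate half a step ahead.
print(plms_eps([[4.0], [3.0], [2.0], [1.0]], order=4))
print(plms_eps([[4.0], [3.0]], order=2))
```

This reuse of past predictions is why PLMS needs only one model evaluation per step once it is bootstrapped. PRK and PIE need several evaluations per step, which is why the option text calls them slow.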

--model: specify the model to use (default cc12m_1_cfg)

-n: the total number of images to sample (default 1)
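Since -n is a total and --batch-size caps each batch, the images are produced in batches of at most batch_size, and the last batch may be smaller. The sketch below shows this split. batch_sizes is a hypothetical helper.

```python
def batch_sizes(n, batch_size):
    """Sizes of the successive batches needed to produce n images."""
    if n < 1 or batch_size < 1:
        raise ValueError("n and batch_size must be positive")
    return [min(batch_size, n - start) for start in range(0, n, batch_size)]


print(batch_sizes(10, 4))  # [4, 4, 2]
```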

--seed: specify the random seed (default 0)

--starting-timestep: specify the starting timestep if an init image is used (range 0-1, default 0.9)
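With an init image, sampling starts partway through the noising process rather than from pure noise. In the sketch below, the init image is scaled by alpha and Gaussian noise scaled by sigma is added at the first remaining timestep. Only timesteps below --starting-timestep are then run, so a value close to 1 keeps little of the init image and a smaller value keeps more. This sketch assumes that a uniform grid is truncated and that the cosine schedule applies. The real scripts may use a different timestep schedule, and start_from_init is a hypothetical name.

```python
import math
import random


def t_to_alpha_sigma(t):
    """Cosine schedule (assumed): alpha = cos(pi t / 2), sigma = sin(pi t / 2)."""
    return math.cos(t * math.pi / 2), math.sin(t * math.pi / 2)


def start_from_init(init, starting_timestep, steps, rng):
    """Noise an init image and return (x, remaining timesteps).

    Assumes the full uniform grid 1, 1 - 1/steps, ..., 1/steps is truncated
    to the timesteps below starting_timestep.
    """
    if not 0.0 < starting_timestep <= 1.0:
        raise ValueError("starting_timestep must be in (0, 1]")
    grid = [1 - i / steps for i in range(steps)]
    ts = [t for t in grid if t < starting_timestep]
    if not ts:
        raise ValueError("no timesteps left; raise steps or starting_timestep")
    alpha, sigma = t_to_alpha_sigma(ts[0])
    # Scale the init image by alpha and add Gaussian noise scaled by sigma.
    x = [alpha * p + sigma * rng.gauss(0, 1) for p in init]
    return x, ts


x, ts = start_from_init([0.2, -0.4, 0.9], starting_timestep=0.9, steps=50, rng=random.Random(0))
```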

--size: the output image size (default auto)

--steps: specify the number of diffusion timesteps (default 50). Fewer steps sample faster at lower quality; DDIM, and especially DDPM, need many more steps.

CLIP guided sampling (all models)

usage: clip_sample.py [-h] [--images [IMAGE ...]] [--batch-size BATCH_SIZE]
                      [--checkpoint CHECKPOINT] [--clip-guidance-scale CLIP_GUIDANCE_SCALE]
                      [--cutn CUTN] [--cut-pow CUT_POW] [--device DEVICE] [--eta ETA]
                      [--init INIT] [--method {ddpm,ddim,prk,plms,pie,plms2,iplms}]
                      [--model {cc12m_1,cc12m_1_cfg,danbooru_128,imagenet_128,wikiart_128,wikiart_256,yfcc_1,yfcc_2}]
                      [-n N] [--seed SEED] [--size SIZE SIZE]
                      [--starting-timestep STARTING_TIMESTEP] [--steps STEPS]
                      [prompts ...]

prompts: the text prompts to use. Relative weights for text prompts can be specified by putting the weight after a colon, for example: "the rise of consciousness:0.5".
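Because these weights are described as relative, one plausible reading is that only their ratios matter. The sketch below normalizes the weights by the absolute value of their sum, so "a:1" "b:3" behaves like "a:2" "b:6". This normalization is an assumption, and normalize_weights is a hypothetical helper.

```python
def normalize_weights(weights):
    """Rescale relative prompt weights so their sum has magnitude 1."""
    total = sum(weights)
    if abs(total) < 1e-3:
        raise ValueError("prompt weights must not sum to zero")
    return [w / abs(total) for w in weights]


print(normalize_weights([1.0, 3.0]))   # [0.25, 0.75]
print(normalize_weights([2.0, -1.0]))  # [2.0, -1.0]
```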

--batch-size: sample this many images at a time (default 1)

--checkpoint: manually specify the model checkpoint file

--clip-guidance-scale: how strongly the result should match the text prompt (default 500). If set to 0, the cc12m_1 model will still be CLIP conditioned and sampling will go faster and use less memory.
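A common CLIP guidance loss is a spherical distance between the L2-normalized image embedding and each prompt embedding, weighted per prompt and multiplied by the guidance scale. The plain-Python sketch below computes this loss on toy vectors, and the function names are hypothetical. During real sampling, the gradient of such a loss with respect to the current image, taken through CLIP with autograd, would steer each step. That part is not shown.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def spherical_dist_loss(a, b):
    """Half the squared angle between two embeddings: 2 * asin(|a - b| / 2) ** 2 for unit vectors."""
    a, b = l2_normalize(a), l2_normalize(b)
    chord = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
    return 2 * math.asin(min(chord / 2, 1.0)) ** 2


def clip_guidance_loss(image_embed, prompt_embeds, weights, clip_guidance_scale):
    """Weighted distance to every prompt embedding, scaled by the guidance scale."""
    total = sum(w * spherical_dist_loss(image_embed, p)
                for w, p in zip(weights, prompt_embeds))
    return clip_guidance_scale * total


# Orthogonal embeddings are pi/2 apart, so the unscaled loss is (pi/2)**2 / 2 = pi**2 / 8.
loss = clip_guidance_loss([1.0, 0.0], [[0.0, 1.0]], [1.0], 500)
```

A larger scale multiplies the resulting gradient. This is why higher values of the option pull the sample harder toward the prompts.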

--cutn: the number of random crops to compute CLIP embeddings for (default 16)

--cut-pow: the exponent applied when sampling random crop sizes; values above 1 favor smaller crops (default 1)
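The two crop options can be sketched as sampling --cutn square boxes per image, each of which would be resized to CLIP's input resolution before embedding. A uniform sample raised to cut_pow sets each crop's size between cut_size and the full side length. The sketch assumes a square image, and cut_size stands for CLIP's input size (224 here). make_cutouts is a hypothetical helper that returns (x, y, size) boxes instead of pixel crops.

```python
import random


def make_cutouts(side, cutn, cut_pow, cut_size, rng):
    """Sample cutn random square crops (x, y, size) from a side x side image.

    Sizes fall in [min(side, cut_size), side); cut_pow > 1 skews them smaller.
    """
    min_size = min(side, cut_size)
    boxes = []
    for _ in range(cutn):
        size = int(rng.random() ** cut_pow * (side - min_size) + min_size)
        # Random top-left corner such that the crop stays inside the image.
        x = rng.randrange(side - size + 1)
        y = rng.randrange(side - size + 1)
        boxes.append((x, y, size))
    return boxes


boxes = make_cutouts(side=256, cutn=16, cut_pow=1.0, cut_size=224, rng=random.Random(0))
```

More cutouts give a less noisy estimate of how well the whole image matches the prompt, at the cost of more CLIP evaluations per step.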

--device: the PyTorch device name to use (default autodetects)

--eta: set to 0 (the default) while using --method ddim for deterministic (DDIM) sampling, 1 for stochastic (DDPM) sampling, and in between to interpolate between the two.

--images: the image prompts to use (local files or HTTP(S) URLs). Relative weights for image prompts can be specified by putting the weight after a colon, for example: "image_1.png:0.5".

--init: specify the init image (optional)

--method: specify the sampling method to use (DDPM, DDIM, PRK, PLMS, PIE, PLMS2, or IPLMS) (default PLMS). DDPM is the original SDE sampling method, DDIM integrates the probability flow ODE using a first-order method, PLMS is fourth-order pseudo Adams-Bashforth, and PLMS2 is second-order pseudo Adams-Bashforth. PRK (fourth-order Pseudo Runge-Kutta) and PIE (second-order Pseudo Improved Euler) are used to bootstrap PLMS and PLMS2 but can also be used on their own (slow). IPLMS is the fourth-order "Improved PLMS" sampler from Fast Sampling of Diffusion Models with Exponential Integrator (https://arxiv.org/abs/2204.13902).

--model: specify the model to use (default cc12m_1)

-n: the total number of images to sample (default 1)

--seed: specify the random seed (default 0)

--starting-timestep: specify the starting timestep if an init image is used (range 0-1, default 0.9)

--size: the output image size (default auto)

--steps: specify the number of diffusion timesteps (default 1000; can be lowered for faster but lower-quality sampling)