FutureXiang / Diffusion

Minimal multi-GPU implementation of Diffusion Models with Classifier-Free Guidance (CFG)

Minimal Diffusion

This is a multi-GPU PyTorch implementation of Diffusion Models with Classifier-Free Guidance (CFG).

This repo contains:

Requirements

Usage

Training

For example, to train a diffusion model on 4 GPUs with automatic mixed precision (AMP) enabled, run one of:

# unconditional
python -m torch.distributed.launch --nproc_per_node=4 cifar.py --config config/cifar_unconditional.yaml --use_amp

# conditional
python -m torch.distributed.launch --nproc_per_node=4 cifar_guide.py --config config/cifar_conditional.yaml --use_amp
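
The launcher starts one process per GPU and tells each process which local GPU it owns, either through a `--local_rank` argument (spelled `--local-rank` in newer PyTorch releases) or through the `LOCAL_RANK` environment variable, depending on the PyTorch version. The following is a minimal sketch of how a script could accept the flags used above. `--config` and `--use_amp` come from the commands in this README; the parser itself and the helper name `parse_launch_args` are illustrative assumptions, not the repository's actual code.

```python
import argparse
import os


def parse_launch_args(argv=None):
    """Parse the README's flags plus the local rank supplied by the launcher.

    Hypothetical helper: the repository's own scripts may parse arguments differently.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)  # path to a YAML config
    parser.add_argument("--use_amp", action="store_true")     # enable mixed precision
    # Accept both spellings; fall back to the LOCAL_RANK environment variable.
    parser.add_argument(
        "--local_rank", "--local-rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    return parser.parse_args(argv)
```

Accepting both spellings keeps the same script usable across launcher versions.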

Sampling

For example, to generate 50,000 image samples with a trained diffusion model on 4 GPUs, run one of:

# unconditional, deterministic fast sampling (i.e. DDIM 100 steps / EDM 18 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/cifar_unconditional.yaml --use_amp
# unconditional, stochastic sampling (i.e. DDPM 1000 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/cifar_unconditional.yaml --use_amp --mode DDPM

# conditional, deterministic fast sampling (i.e. DDIM 100 steps / EDM 18 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample_guide.py --config config/cifar_conditional.yaml --use_amp
# conditional, stochastic sampling (i.e. DDPM 1000 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample_guide.py --config config/cifar_conditional.yaml --use_amp --mode DDPM
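
Classifier-free guidance, as described in the CFG paper, evaluates the network twice per sampling step, once with the class label and once without it, and then extrapolates from the unconditional prediction towards the conditional one. A minimal sketch on plain Python lists follows. The function name `cfg_combine`, the weight `w`, and the `(1 + w)` convention are assumptions for illustration; the repository's scripts may parameterize guidance differently.

```python
def cfg_combine(eps_cond, eps_uncond, w):
    """Blend two noise predictions with guidance weight w (illustrative only).

    Computes (1 + w) * eps_cond - w * eps_uncond element-wise.
    """
    return [(1.0 + w) * c - w * u for c, u in zip(eps_cond, eps_uncond)]
```

With `w = 0` this reduces to the plain conditional prediction; larger `w` pushes samples further towards the class condition.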

Compute FID

For example, to compute the FID between the training set and a directory of generated samples, run:

python -m pytorch_fid data/cifar10-pngs/ 2000_unconditional/EMAgenerated_ep1999_ddpm/pngs/
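
For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of the real and generated images. In the standard notation (the symbols are not taken from this repository), with feature means $\mu_r, \mu_g$ and covariances $\Sigma_r, \Sigma_g$ for the real and generated sets:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Lower is better, and the value is zero when both sets have identical feature statistics.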

Results

| Config | Model | Network | Cond | FID (best sampler / NFE) | FID (fast sampler / NFE) |
| --- | --- | --- | --- | --- | --- |
| cifar_unconditional.yaml | DDPM | 35.7M UNet | no | 3.03 (DDPM / 1000) | 3.54 (DDIM / 100) |
| cifar_unconditional_cosine.yaml | DDPM (cosine) | 35.7M UNet | no | 3.18 (DDPM / 1000) | 3.56 (DDIM / 100) |
| cifar_conditional.yaml | DDPM | 38.6M UNet | yes | 3.27 (DDPM / 2000) | 3.46 (DDIM / 200) |
| cifar_unconditional_vit.yaml | DDPM | 44.3M UViT | no | 3.33 (DDPM / 1000) | 4.51 (DDIM / 100) |
| cifar_conditional_vit.yaml | DDPM | 44.3M UViT | yes | 2.87 (DDPM / 2000) | 3.19 (DDIM / 200) |
| cifar_unconditional_EDM.yaml | EDM | 56.5M UNet | no | 2.22 (EDM / 35) | 2.22 (EDM / 35) |
| cifar_conditional_EDM.yaml | EDM | 61.1M UNet | yes | 2.06 (EDM / 53) | 2.06 (EDM / 53) |
| N/A | EDM | 44.3M UViT | no | NOT WORKING | NOT WORKING |
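
The NFE (number of function evaluations) figures can be related to the step counts quoted in the Sampling section. The sketch below rests on three assumptions that the README does not state: DDPM and DDIM call the network once per step; EDM's second-order Heun sampler calls it twice per step except on the last step; and guided sampling doubles the count because each step needs a conditional and an unconditional pass. The helper name `nfe` is hypothetical. These assumptions reproduce the DDPM and DDIM rows and the unconditional EDM row; the conditional EDM value of 53 does not follow from them and is not explained here.

```python
def nfe(steps, second_order=False, guided=False):
    """Count network evaluations for one sampling run (illustrative assumptions)."""
    # A Heun-style second-order sampler skips the correction on the final step.
    evals = 2 * steps - 1 if second_order else steps
    # Guidance is assumed to need two forward passes per evaluation.
    return 2 * evals if guided else evals
```

For instance, `nfe(18, second_order=True)` gives 35 and `nfe(1000, guided=True)` gives 2000, matching the corresponding table entries.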

Acknowledgments

This repository is built on numerous open-source codebases such as DDPM, DDPM-pytorch, DDIM, EDM, ADM, U-ViT, and CFG-MNIST.