
CoDi: Conditional Diffusion Distillation (CVPR24)

CoDi: Conditional Diffusion Distillation for Higher-Fidelity and Faster Image Generation
Kangfu Mei 1, 2, Mauricio Delbracio 2, Hossein Talebi 2, Zhengzhong Tu 2, Vishal M. Patel 1, Peyman Milanfar 2
1Johns Hopkins University
2Google Research

[Paper](https://arxiv.org/abs/2310.01407) [Project Page](https://fast-codi.github.io/)

Disclaimer: This is not an official Google product. This repository contains an unofficial implementation; please refer to the official implementation at https://github.com/google-research/google-research/tree/master/CoDi.

Disclaimer: All models in this repository were trained using publicly available data.

Introduction

CoDi can efficiently distill the sampling steps of a conditional diffusion model directly from an unconditional one (e.g., Stable Diffusion), enabling rapid generation of high-quality images in as few as 1-4 steps under various conditional settings (e.g., inpainting, InstructPix2Pix).
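
For intuition, the gist of this kind of consistency-style distillation (cf. the --distill_loss consistency_x option used in the training commands below) can be sketched in a few lines of JAX. This is a simplified illustration with our own hypothetical function names, not the repository's actual training code:

import jax
import jax.numpy as jnp

# Hypothetical sketch: teacher_ode_step integrates the teacher's probability-flow
# ODE from time t to an earlier time s; student_x0 predicts the clean image x0.
def consistency_x_loss(student_params, target_params, student_x0,
                       teacher_ode_step, x_t, t, s, cond):
    x_s = teacher_ode_step(x_t, t, s, cond)           # one teacher ODE step
    pred_t = student_x0(student_params, x_t, t, cond)
    pred_s = student_x0(target_params, x_s, s, cond)  # target/EMA network (cf. --ema_decay)
    # Consistency in x-space: predictions from adjacent trajectory points should agree.
    return jnp.mean((pred_t - jax.lax.stop_gradient(pred_s)) ** 2)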

[Teaser figure]

On the standard real-world image super-resolution benchmark, we show that CoDi matches the FID and LPIPS of 50-step sampling with only 4 steps, largely outperforming previous guided-distillation and consistency models. On less challenging tasks such as text-guided inpainting, we show that a new parameter-efficient distillation method, first proposed here, can even surpass the original 50-step sampling on the FID and LPIPS metrics.

[Performance comparison figure]

Contents

  1. Training CoDi on HuggingFace Data
  2. Training CoDi on Your Own Data
  3. Testing CoDi on Canny Images
  4. Citations
  5. Acknowledgement

Note: The following instructions are adapted from https://github.com/huggingface/community-events/blob/main/jax-controlnet-sprint/README.md.

Training CoDi on HuggingFace Data

All you need to do is update DATASET_NAME to a dataset on the HuggingFace Hub to train on (it can be your own, possibly private, dataset). A good starting point is the datasets listed at https://huggingface.co/spaces/jax-diffusers-event/leaderboard.

export HF_HOME="/data/kmei1/huggingface/"
export DISK_DIR="/data/kmei1/huggingface/cache"
export MODEL_DIR="stabilityai/stable-diffusion-2-1"
export OUTPUT_DIR="canny_model"
export DATASET_NAME="jax-diffusers-event/canny_diffusiondb"
export NCCL_P2P_DISABLE=1
export CUDA_VISIBLE_DEVICES=5
# export XLA_FLAGS="--xla_force_host_platform_device_count=4 --xla_dump_to=/tmp/foo"

python3 train_codi_flax.py \
 --pretrained_model_name_or_path $MODEL_DIR \
 --output_dir $OUTPUT_DIR \
 --dataset_name $DATASET_NAME \
 --load_from_disk \
 --cache_dir $DISK_DIR \
 --resolution 512 \
 --learning_rate 8e-6 \
 --train_batch_size 2 \
 --gradient_accumulation_steps 2 \
 --revision main \
 --from_pt \
 --mixed_precision bf16 \
 --max_train_steps 200_000 \
 --checkpointing_steps 10_000 \
 --validation_steps 100 \
 --dataloader_num_workers 8 \
 --distill_learning_steps 20 \
 --ema_decay 0.99995 \
 --onestepode uncontrol \
 --onestepode_control_params target \
 --onestepode_sample_eps vprediction \
 --cfg_aware_distill \
 --distill_loss consistency_x \
 --distill_type conditional \
 --image_column original_image \
 --caption_column prompt \
 --conditioning_image transformed_image \
 --report_to wandb \
 --validation_image "figs/control_bird_canny.png" \
 --validation_prompt "birds"

Note that you may need to change --image_column, --caption_column, and --conditioning_image to match your selected dataset. For example, the values above are the ones required for the jax-diffusers-event/canny_diffusiondb dataset; see https://huggingface.co/datasets/jax-diffusers-event/canny_diffusiondb.
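
If you are unsure which columns a dataset provides, a quick way to check is to stream a single sample and print its fields. A minimal sketch using the datasets library (the dataset name is the one from the command above):

from datasets import load_dataset

# Stream one sample so the full dataset is not downloaded up front.
ds = load_dataset("jax-diffusers-event/canny_diffusiondb", split="train", streaming=True)
sample = next(iter(ds))
print(sample.keys())  # expect original_image, prompt, transformed_image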

Training CoDi on Your Own Data

Data preprocessing

Here we demonstrate how to prepare a large dataset for training a ControlNet-style model that generates images conditioned on an image representation containing only edge information (extracted with Canny edge detection).

More specifically, we use an example script defined in https://github.com/huggingface/community-events/blob/main/jax-controlnet-sprint/dataset_tools/coyo_1m_dataset_preprocess.py:

python3 coyo_1m_dataset_preprocess.py \
 --train_data_dir="/data/dataset" \
 --cache_dir="/data" \
 --max_train_samples=1000000 \
 --num_proc=32

Once the script finishes running, you will find a data folder at the specified train_data_dir with the following structure:

data
├── images
│   ├── image_1.png
│   ├── .......
│   └── image_1000000.jpeg
├── processed_images
│   ├── image_1.png
│   ├── .......
│   └── image_1000000.jpeg
└── meta.jsonl
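
Before launching training, it can help to sanity-check the generated metadata. A minimal sketch (the exact fields in each record depend on the preprocessing script, so treat them as assumptions):

import json

# Print the first few metadata records to verify image paths and captions.
with open("/data/dataset/meta.jsonl") as f:
    for i, line in enumerate(f):
        print(json.loads(line))
        if i == 2:
            break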

Training

All you need to do is update DATASET_DIR with the correct path to your data folder.

Here is an example that runs the training script and loads the dataset from disk:

export HF_HOME="/data/huggingface/"
export DISK_DIR="/data/huggingface/cache"
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="/data/canny_model"
export DATASET_DIR="/data/dataset"

python3 train_codi_flax.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --train_data_dir=$DATASET_DIR \
 --load_from_disk \
 --cache_dir=$DISK_DIR \
 --resolution=512 \
 --learning_rate=1e-5 \
 --train_batch_size=2 \
 --revision="non-ema" \
 --from_pt \
 --max_train_steps=500000 \
 --checkpointing_steps=10000 \
 --dataloader_num_workers=16 \
 --distill_learning_steps=50 \
 --distill_timestep_scaling=10 \
 --onestepode=control \
 --onestepode_control_params=target \
 --onestepode_sample_eps=v_prediction \
 --distill_loss=consistency_x

Testing CoDi on Canny Images

[Figure: Canny conditioning image (left) and our result with 4-step sampling (right); prompt: "birds"]

We provide a pretrained canny-edge-to-image model following the ControlNet experiments at https://huggingface.co/lllyasviel/sd-controlnet-canny. Note that we use open-source data, i.e., jax-diffusers-event/canny_diffusiondb, so there are stylistic differences between ControlNet's results and ours.
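
If you want to try your own images, you first need a Canny edge map as the conditioning input. A minimal sketch using OpenCV (the file paths and thresholds are illustrative, not taken from this repository):

import cv2
import numpy as np
from PIL import Image

# Extract Canny edges and replicate them to 3 channels, as in the usual
# ControlNet canny preprocessing; thresholds 100/200 are common defaults.
image = np.array(Image.open("my_input.png").convert("RGB"))  # hypothetical input
edges = cv2.Canny(image, 100, 200)
edges = np.stack([edges] * 3, axis=-1)
Image.fromarray(edges).save("my_canny.png")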

export HF_HOME="/data/kmei1/huggingface/"
export DISK_DIR="/data/kmei1/huggingface/cache"
export MODEL_DIR="stabilityai/stable-diffusion-2-1"
export NCCL_P2P_DISABLE=1
export CUDA_VISIBLE_DEVICES=5

# download pretrained checkpoint and relocate it.
mkdir -p experiments && wget https://www.cis.jhu.edu/~kmei1/publics/codi/canny_99000.tar.fz && tar -xzvf canny_99000.tar.fz -C experiments

python test_canny.py

# or gradio user interface
python gradio_canny_to_image.py

The user interface looks like this 👇 [demo screenshot]
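
For reference, the overall shape of such a Gradio app is roughly as below; this is a placeholder sketch, not the contents of gradio_canny_to_image.py, and generate stands in for the distilled 4-step pipeline:

import gradio as gr

def generate(canny_image, prompt):
    # Placeholder: run the distilled conditional model here (~4 sampling steps).
    return canny_image

demo = gr.Interface(
    fn=generate,
    inputs=[gr.Image(label="Canny image"), gr.Textbox(label="Prompt")],
    outputs=gr.Image(label="Generated image"),
)
demo.launch()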

Citations

You may want to cite:

@article{mei2023conditional,
  title={CoDi: Conditional Diffusion Distillation for Higher-Fidelity and Faster Image Generation},
  author={Mei, Kangfu and Delbracio, Mauricio and Talebi, Hossein and Tu, Zhengzhong and Patel, Vishal M and Milanfar, Peyman},
  journal={arXiv preprint arXiv:2310.01407},
  year={2023}
}

Acknowledgement

The code is based on Diffusers and HuggingFace. Please also follow their licenses. Thanks for their awesome work.