In this paper, we propose Flash Diffusion, an efficient, fast, versatile and LoRA-compatible distillation method to accelerate the generation of pre-trained diffusion models. The method reaches state-of-the-art performance in terms of FID and CLIP-Score for few-step image generation on the COCO 2014 and COCO 2017 datasets, while requiring only several GPU hours of training and fewer trainable parameters than existing methods. Beyond its efficiency, the method's versatility is demonstrated across several tasks such as text-to-image, inpainting, face-swapping and super-resolution, across different diffusion backbones using either UNet-based denoisers (SD1.5, SDXL) or DiT (Pixart-α), and with adapters. In all cases, the method drastically reduces the number of sampling steps while maintaining very high-quality image generation.
Our method aims to be fast, reliable, and adaptable to various uses. We propose to train a student model to predict, in a single step, a denoised multi-step teacher prediction of a corrupted input sample. Additionally, we sample timesteps from an adaptable distribution that shifts during training to help the student model target specific timesteps.
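To make the idea of a timestep distribution that shifts during training more concrete, here is a minimal sketch. It is not the sampler used in this repo: the anchor timesteps, the start and end weights, and the linear interpolation between them are all assumptions chosen for illustration.

```python
import random


def timestep_weights(progress, start, end):
    """Linearly interpolate mixture weights as training goes from 0.0 to 1.0."""
    progress = min(max(progress, 0.0), 1.0)
    mixed = [(1.0 - progress) * s + progress * e for s, e in zip(start, end)]
    total = sum(mixed)
    return [w / total for w in mixed]


def sample_timestep(progress, anchors, start, end, rng=random):
    """Draw one anchor timestep using the weights for the current progress."""
    weights = timestep_weights(progress, start, end)
    return rng.choices(anchors, weights=weights, k=1)[0]


# Hypothetical setup: 4 anchor timesteps, one per student sampling step.
ANCHORS = [249, 499, 749, 999]
START = [0.0, 0.0, 0.5, 0.5]  # early training: only the noisiest timesteps
END = [0.4, 0.2, 0.2, 0.2]    # late training: every timestep, low noise favoured

if __name__ == "__main__":
    for progress in (0.0, 0.5, 1.0):
        draws = [sample_timestep(progress, ANCHORS, START, END) for _ in range(5)]
        print(progress, draws)
```

Changing `START` and `END` is enough to steer which timesteps the student sees most often at each stage of training.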
## Results
Flash Diffusion is compatible with various backbones, such as:
- [Flash Stable Diffusion 3](https://huggingface.co/jasperai/flash-sd3), distilled from a [Stable Diffusion 3 teacher](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- [Flash SDXL](https://huggingface.co/jasperai/flash-sdxl), distilled from an [SDXL teacher](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
- [Flash Pixart (DiT)](https://huggingface.co/jasperai/flash-pixart), distilled from a [Pixart-α teacher](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS)
- [Flash SD](https://huggingface.co/jasperai/flash-sd), distilled from an [SD1.5 teacher](https://huggingface.co/runwayml/stable-diffusion-v1-5)
It can also be used to accelerate existing LoRAs in a **training-free** manner. See [this section](#combining-flash-diffusion-with-existing-loras-) for more details.
### Varying backbones for *Text-to-image*
#### Flash Stable Diffusion 3 (MMDiT)
Images generated using 4 NFEs
#### Flash SDXL (UNet)
Images generated using 4 NFEs
#### Flash Pixart (DiT)
Images generated using 4 NFEs
#### Flash SD
Images generated using 4 NFEs
### Varying Use-cases
#### Image-inpainting
#### Image-upscaling
#### Face-swapping
#### T2I-Adapters
### Training Free LoRA Acceleration
#### SDXL LoRAs
Images generated using 4 NFEs
## Setup
To get up and running, first create a virtual environment with at least `python3.10` installed and activate it:
### With `venv`
```bash
python3.10 -m venv envs/flash_diffusion
source envs/flash_diffusion/bin/activate
```
### With `conda`
```bash
conda create -n flash_diffusion python=3.10
conda activate flash_diffusion
```
Then install the required dependencies (if on GPU) and the repo in editable mode:
```bash
pip install --upgrade pip
pip install -r requirements.txt
pip install -e .
```
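If an install fails unexpectedly, it can help to confirm that the active environment really runs Python 3.10 or newer. The helper below is a small sketch for that check; the function name `meets_minimum` is hypothetical and not part of this package.

```python
import sys

MIN_PYTHON = (3, 10)


def meets_minimum(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True when the (major, minor) version is at least `minimum`."""
    return tuple(version_info[:2]) >= minimum


if __name__ == "__main__":
    print("Python OK" if meets_minimum() else "Python 3.10+ required")
```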
## Distilling existing T2I models
The scripts to reproduce the main experiments of the paper are located in the `examples` folder. We provide 5 different scripts:
- `train_flash_sd3.py`: Distils [SD3 model](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- `train_flash_sdxl.py`: Distils [SDXL model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
- `train_flash_pixart.py`: Distils [Pixart-α model](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS)
- `train_flash_canny_adapter.py`: Distils a [T2I Canny Adapter](https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0?library=true)
- `train_flash_sd.py`: Distils [SD1.5 model](https://huggingface.co/runwayml/stable-diffusion-v1-5)
In `examples/configs`, you will find the `yaml` configuration associated with each script. The only thing you need to do is amend the `SHARDS_PATH_OR_URLS` section of the `yaml` so that the model is trained on your own data. Please note that this package uses [`webdataset`](https://github.com/webdataset/webdataset) to handle the datastream, so the URLs you use should be formatted according to the [`webdataset format`](https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format). In particular, for these examples, each sample needs to be composed of a `jpg` file containing the image and a `json` file containing the caption under the key `caption` and the image aesthetics score under the key `aesthetic_score`:
```python
sample = {
    "jpg": dummy_image,
    "json": {
        "caption": "dummy caption",
        "aesthetic_score": 6.0
    }
}
```
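As an illustration of this layout, the sketch below packs samples into a tar shard using only the Python standard library. It relies on the WebDataset convention that files sharing a basename (`sample0000.jpg`, `sample0000.json`) form one sample. The function name `write_shard` and the shard file name are hypothetical, and the `webdataset` package ships its own writers that you may prefer.

```python
import io
import json
import tarfile


def write_shard(shard_path, samples):
    """Write (key, jpeg_bytes, caption, aesthetic_score) tuples to a tar shard."""
    with tarfile.open(shard_path, "w") as tar:
        for key, jpeg_bytes, caption, score in samples:
            meta = json.dumps(
                {"caption": caption, "aesthetic_score": score}
            ).encode("utf-8")
            # Files sharing the same basename are grouped into one sample.
            for ext, payload in (("jpg", jpeg_bytes), ("json", meta)):
                info = tarfile.TarInfo(name=f"{key}.{ext}")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))


# Usage (assuming `image_bytes` holds the raw bytes of a JPEG file):
# write_shard("shard-000000.tar", [("sample0000", image_bytes, "dummy caption", 6.0)])
```

The resulting shard path can then be listed under `SHARDS_PATH_OR_URLS` in the configuration file.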
The scripts can then be launched by simply running:
```bash
# Set the number of gpus & nodes you want to use
export SLURM_NPROCS=1
export SLURM_NNODES=1
# Distills SD1.5
python3.10 examples/train_flash_sd.py
# Distills SDXL1.0
python3.10 examples/train_flash_sdxl.py
# Distills Pixart-α
python3.10 examples/train_flash_pixart.py
# Distills T2I Canny adapter
python3.10 examples/train_flash_canny_adapter.py
```
## Example of a distillation training with a custom conditional diffusion model
This package is also intended to support custom model distillation.
```python
from copy import deepcopy

from flash.models.unets import DiffusersUNet2DCondWrapper
from flash.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig
from flash.models.embedders import (
    ClipEmbedder,
    ClipEmbedderConfig,
    ClipEmbedderWithProjection,
    ConditionerWrapper,
    # Assumed to be exported alongside the other embedders
    TorchNNEmbedder,
    TorchNNEmbedderConfig,
)

# Create the VAE
vae_config = AutoencoderKLDiffusersConfig(
    "stabilityai/sdxl-vae"  # VAE from HF Hub
)
vae = AutoencoderKLDiffusers(config=vae_config)

## Create the Conditioners ##

# A Clip conditioner returning 2 types of conditioning
embedder_1_config = ClipEmbedderConfig(
    version="stabilityai/stable-diffusion-xl-base-1.0",  # from HF Hub
    text_embedder_subfolder="text_encoder_2",
    tokenizer_subfolder="tokenizer_2",
    input_key="text",
    always_return_pooled=True,  # Return a 1-dimensional tensor
)
embedder_1 = ClipEmbedder(config=embedder_1_config)

# Embedder acting on a low-resolution image injected in the UNet via concatenation
embedder_2_config = TorchNNEmbedderConfig(
    nn_modules=["torch.nn.Conv2d"],
    nn_modules_kwargs=[
        dict(
            in_channels=3,
            out_channels=6,
            kernel_size=3,
            padding=1,
            stride=2,
        ),
    ],
    input_key="downsampled_image",
    unconditional_conditioning_rate=0.1,  # example value (assumption), tune for your setup
)
embedder_2 = TorchNNEmbedder(config=embedder_2_config)

conditioner_wrapper = ConditionerWrapper(
    conditioners=[embedder_1, embedder_2]
)

# Create the teacher denoiser
teacher_denoiser = DiffusersUNet2DCondWrapper(
    in_channels=4 + 6,  # VAE channels + concat conditioning
    out_channels=4,  # VAE channels
    cross_attention_dim=1280,  # cross-attention conditioning
    projection_class_embeddings_input_dim=1280,  # add conditioning
    class_embed_type="projection",
)

# Load the teacher weights
...

# Create the student denoiser as a copy of the teacher
student_denoiser = deepcopy(teacher_denoiser)
```
## Inference with a Hugging Face pipeline 🤗
```python
import torch
from diffusers import PixArtAlphaPipeline, Transformer2DModel, LCMScheduler
from peft import PeftModel

# Load LoRA
transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="transformer",
    torch_dtype=torch.float16
)
transformer = PeftModel.from_pretrained(
    transformer,
    "jasperai/flash-pixart"
)

# Pipeline
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    transformer=transformer,
    torch_dtype=torch.float16
)

# Scheduler
pipe.scheduler = LCMScheduler.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="scheduler",
    timestep_spacing="trailing",
)

pipe.to("cuda")

prompt = "A raccoon reading a book in a lush forest."

image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]
```
## Using Flash in ComfyUI
To use FlashSDXL locally with ComfyUI, you need to:
1) Make sure your ComfyUI install is up to date
2) Download the checkpoint from Hugging Face. To do so, go to "Files and versions", open the [`comfy`](https://huggingface.co/jasperai/flash-sdxl/tree/main/comfy) folder and hit the download button next to `FlashSDXL.safetensors`
3) Move the new checkpoint file to your local `comfyUI/models/loras/` folder
4) Use it as a LoRA on top of `sd_xl_base_1.0_0.9vae.safetensors`. A simple ComfyUI `workflow.json` is provided in this repo in `examples/comfy`

*Disclaimer: the model has been trained to work with a CFG scale of 1 and an LCM scheduler, but these parameters can be tweaked a bit.*
## Combining Flash Diffusion with Existing LoRAs 🎨
Flash Diffusion models can also be combined with existing LoRAs to unlock few-step generation in a **training-free** manner. They can be integrated straight into Hugging Face pipelines. See an example below.
```python
from diffusers import DiffusionPipeline, LCMScheduler
import torch

user_lora_id = "TheLastBen/Papercut_SDXL"
trigger_word = "papercut"

flash_lora_id = "jasperai/flash-sdxl"

# Load Pipeline
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16"
)

# Set scheduler
pipe.scheduler = LCMScheduler.from_config(
    pipe.scheduler.config
)

# Load LoRAs
pipe.load_lora_weights(flash_lora_id, adapter_name="flash")
pipe.load_lora_weights(user_lora_id, adapter_name="lora")

pipe.set_adapters(["flash", "lora"], adapter_weights=[1.0, 1.0])
pipe.to(device="cuda", dtype=torch.float16)

prompt = f"{trigger_word} a cute corgi"

image = pipe(
    prompt,
    num_inference_steps=4,
    guidance_scale=0
).images[0]
```
# License
This code is released under the [Creative Commons BY-NC 4.0 license](https://creativecommons.org/licenses/by-nc/4.0/legalcode.en).
# Citation
If you find this work useful or use it in your research, please consider citing us:
```bibtex
@misc{chadebec2024flash,
title={Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation},
author={Clement Chadebec and Onur Tasar and Eyal Benaroche and Benjamin Aubin},
year={2024},
eprint={2406.02347},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```