
# ⚡ Flash Diffusion ⚡

This repository is the official implementation of the paper [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](https://arxiv.org/abs/2406.02347). See the [project page](https://gojasper.github.io/flash-diffusion-project/) for additional results.



Images generated using 4 NFEs

In this paper, we propose Flash Diffusion: an efficient, fast, versatile, and LoRA-compatible distillation method to accelerate the generation of pre-trained diffusion models. The method reaches state-of-the-art performance in terms of FID and CLIP-Score for few-step image generation on the COCO 2014 and COCO 2017 datasets, while requiring only several GPU hours of training and fewer trainable parameters than existing methods. In addition to its efficiency, the versatility of the method is demonstrated across several tasks such as text-to-image, inpainting, face-swapping, and super-resolution, and across different diffusion backbones, whether UNet-based denoisers (SD1.5, SDXL) or a DiT (Pixart-α), as well as adapters. In all cases, the method drastically reduces the number of sampling steps while maintaining very high-quality image generation.

## Method

Our method is designed to be fast, reliable, and adaptable to a variety of use-cases. We propose to train a student model to predict, in a single step, the multi-step denoised prediction that the teacher produces from a corrupted input sample. Additionally, we sample timesteps from an adaptive distribution that shifts during training to help the student model target specific timesteps.
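
For illustration, here is a minimal PyTorch sketch of this distillation objective, omitting the additional loss terms and implementation details of the actual training scripts; every name below (`student`, `teacher_sampler`, `timestep_probs`, ...) is a hypothetical placeholder, and the real training loops live in `examples/train_flash_*.py`.

```python
import torch
import torch.nn.functional as F


def flash_distillation_loss(student, teacher_sampler, scheduler, x0, cond,
                            timestep_probs, num_teacher_steps=4):
    """Simplified Flash Diffusion distillation objective (illustrative only).

    Hypothetical arguments:
      student         -- trainable denoiser mapping (x_t, t, cond) to an x0 estimate
      teacher_sampler -- few-step sampler wrapping the frozen teacher,
                         mapping (x_t, t, cond, num_steps) to a denoised sample
      scheduler       -- diffusers-style noise scheduler exposing `add_noise`
      timestep_probs  -- adaptive timestep distribution, shifted during training
    """
    # Draw timesteps from the adaptive distribution
    t = torch.multinomial(timestep_probs, num_samples=x0.shape[0], replacement=True)

    # Corrupt the clean latents with the forward diffusion process
    noise = torch.randn_like(x0)
    x_t = scheduler.add_noise(x0, noise, t)

    # Teacher target: a multi-step denoised prediction of the corrupted sample
    with torch.no_grad():
        target = teacher_sampler(x_t, t, cond, num_teacher_steps)

    # Student: reproduce that target in a single step
    pred = student(x_t, t, cond)
    return F.mse_loss(pred, target)
```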

## Results

Flash Diffusion is compatible with various backbones, such as:

- [Flash Stable Diffusion 3](https://huggingface.co/jasperai/flash-sd3), distilled from a [Stable Diffusion 3 teacher](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- [Flash SDXL](https://huggingface.co/jasperai/flash-sdxl), distilled from an [SDXL teacher](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
- [Flash Pixart (DiT)](https://huggingface.co/jasperai/flash-pixart), distilled from a [Pixart-α teacher](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS)
- [Flash SD](https://huggingface.co/jasperai/flash-sd), distilled from a [SD1.5 teacher](https://huggingface.co/runwayml/stable-diffusion-v1-5)

It can also be used to accelerate existing LoRAs in a **training-free** manner. See this [section](#combining-flash-diffusion-with-existing-loras-) to learn more.

### Varying backbones for *Text-to-image*

#### Flash Stable Diffusion 3 (MMDiT)

Images generated using 4 NFEs

#### Flash SDXL (UNet)

Images generated using 4 NFEs

#### Flash Pixart (DiT)

Images generated using 4 NFEs

#### Flash SD

Images generated using 4 NFEs

### Varying Use-cases
#### Image-inpainting

#### Image-upscaling

#### Face-swapping

#### T2I-Adapters

### Training Free LoRA Acceleration

#### SDXL LoRAs

Images generated using 4 NFEs

## Setup

To get up and running, first create a virtual environment with at least `python3.10` installed and activate it.

### With `venv`

```bash
python3.10 -m venv envs/flash_diffusion
source envs/flash_diffusion/bin/activate
```

### With `conda`

```bash
conda create -n flash_diffusion python=3.10
conda activate flash_diffusion
```

Then install the required dependencies (if on GPU) and the repo in editable mode:

```bash
pip install --upgrade pip
pip install -r requirements.txt
pip install -e .
```

## Distilling existing T2I models

The scripts to reproduce the main experiments of the paper are located in `examples`. We provide five different scripts:

- `train_flash_sd3.py`: distills the [SD3 model](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- `train_flash_sdxl.py`: distills the [SDXL model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
- `train_flash_pixart.py`: distills the [Pixart-α model](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS)
- `train_flash_canny_adapter.py`: distills a [T2I Canny Adapter](https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0?library=true)
- `train_flash_sd.py`: distills the [SD1.5 model](https://huggingface.co/runwayml/stable-diffusion-v1-5)

In `examples/configs`, you will find the configuration `yaml` associated with each script. The only thing you need to do is amend the `SHARDS_PATH_OR_URLS` section of the `yaml` so the model is trained on your own data. Please note that this package uses [`webdataset`](https://github.com/webdataset/webdataset) to handle the data stream, so the URLs you provide should be formatted according to the [`webdataset` format](https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format). In particular, for these examples, each sample needs to be composed of a `jpg` file containing the image and a `json` file containing the caption under the key `caption` and the image aesthetics score under `aesthetic_score` (see the shard-writing sketch at the end of this section):

```
sample = {
    "jpg": dummy_image,
    "json": {
        "caption": "dummy caption",
        "aesthetic_score": 6.0
    }
}
```

The scripts can then be launched by simply running:

```bash
# Set the number of gpus & nodes you want to use
export SLURM_NPROCS=1
export SLURM_NNODES=1

# Distills SD1.5
python3.10 examples/train_flash_sd.py

# Distills SDXL1.0
python3.10 examples/train_flash_sdxl.py

# Distills Pixart-α
python3.10 examples/train_flash_pixart.py

# Distills T2I Canny adapter
python3.10 examples/train_flash_canny_adapter.py
```
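
For reference, here is a minimal sketch of how such shards could be produced with `webdataset`'s `ShardWriter`, assuming a hypothetical folder of `.jpg` images with one `.txt` caption file per image and a placeholder aesthetic score:

```python
from pathlib import Path

from PIL import Image
import webdataset as wds

# Hypothetical folder of training images with one caption .txt per image
image_dir = Path("my_dataset/images")

# Write webdataset-format .tar shards that can then be listed under
# SHARDS_PATH_OR_URLS in the example configs
with wds.ShardWriter("my_dataset/shards/train-%06d.tar", maxcount=10_000) as writer:
    for i, image_path in enumerate(sorted(image_dir.glob("*.jpg"))):
        caption = image_path.with_suffix(".txt").read_text().strip()
        writer.write(
            {
                "__key__": f"sample{i:08d}",
                "jpg": Image.open(image_path).convert("RGB"),
                "json": {
                    "caption": caption,
                    # Placeholder score; use your own aesthetic predictor if available
                    "aesthetic_score": 6.0,
                },
            }
        )
```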

## Example of a distillation training with a custom conditional diffusion model

This package is also intended to support custom model distillation.

```python
from copy import deepcopy

from flash.models.unets import DiffusersUNet2DCondWrapper
from flash.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig
from flash.models.embedders import (
    ClipEmbedder,
    ClipEmbedderConfig,
    ConditionerWrapper,
    TorchNNEmbedder,
    TorchNNEmbedderConfig,
)

# Create the VAE
vae_config = AutoencoderKLDiffusersConfig(
    version="stabilityai/sdxl-vae"  # VAE from the HF Hub
)
vae = AutoencoderKLDiffusers(config=vae_config)

# Create the conditioners
# A CLIP conditioner returning 2 types of conditioning
embedder_1_config = ClipEmbedderConfig(
    version="stabilityai/stable-diffusion-xl-base-1.0",  # from the HF Hub
    text_embedder_subfolder="text_encoder_2",
    tokenizer_subfolder="tokenizer_2",
    input_key="text",
    always_return_pooled=True,  # Also return a pooled 1-dimensional tensor
)
embedder_1 = ClipEmbedder(config=embedder_1_config)

# Embedder acting on a low-resolution image injected into the UNet via concatenation
embedder_2_config = TorchNNEmbedderConfig(
    nn_modules=["torch.nn.Conv2d"],
    nn_modules_kwargs=[
        dict(
            in_channels=3,
            out_channels=6,
            kernel_size=3,
            padding=1,
            stride=2,
        ),
    ],
    input_key="downsampled_image",
    unconditional_conditioning_rate=0.2,  # placeholder conditioning dropout rate
)
embedder_2 = TorchNNEmbedder(config=embedder_2_config)

conditioner_wrapper = ConditionerWrapper(
    conditioners=[embedder_1, embedder_2]
)

# Create the teacher denoiser
teacher_denoiser = DiffusersUNet2DCondWrapper(
    in_channels=4 + 6,  # VAE channels + concat conditioning
    out_channels=4,  # VAE channels
    cross_attention_dim=1280,  # cross-attention conditioning
    projection_class_embeddings_input_dim=1280,  # add conditioning
    class_embed_type="projection",
)

# Load the teacher weights
...

# Create the student denoiser
student_denoiser = deepcopy(teacher_denoiser)
```

## Inference with a Hugging Face pipeline 🤗

```python
import torch
from diffusers import PixArtAlphaPipeline, Transformer2DModel, LCMScheduler
from peft import PeftModel

# Load the LoRA
transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="transformer",
    torch_dtype=torch.float16
)
transformer = PeftModel.from_pretrained(
    transformer,
    "jasperai/flash-pixart"
)

# Pipeline
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    transformer=transformer,
    torch_dtype=torch.float16
)

# Scheduler
pipe.scheduler = LCMScheduler.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="scheduler",
    timestep_spacing="trailing",
)

pipe.to("cuda")

prompt = "A raccoon reading a book in a lush forest."

image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]
```
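
The UNet-based checkpoints can be used in a similar way. Below is a minimal sketch for Flash SD, assuming the `jasperai/flash-sd` checkpoint loads as a standard diffusers LoRA on top of its SD1.5 teacher (the same pattern applies to `jasperai/flash-sdxl`):

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

# Base SD1.5 pipeline (the teacher model)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

# LCM scheduler with trailing timestep spacing, as in the Pixart example above
pipe.scheduler = LCMScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)

# Load the Flash SD distillation weights (assumed diffusers LoRA format) and fuse them
pipe.load_lora_weights("jasperai/flash-sd")
pipe.fuse_lora()

pipe.to("cuda")

prompt = "A raccoon reading a book in a lush forest."

image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]
```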

## Using Flash in ComfyUI

To use FlashSDXL locally with ComfyUI, you need to:

1) Make sure your ComfyUI install is up to date.
2) Download the checkpoint from Hugging Face: in the "Files and versions" tab, open the [`comfy`](https://huggingface.co/jasperai/flash-sdxl/tree/main/comfy) folder and hit the download button next to `FlashSDXL.safetensors`.
3) Move the new checkpoint file to your local `comfyUI/models/loras/` folder.

Use it as a LoRA on top of `sd_xl_base_1.0_0.9vae.safetensors`; a simple ComfyUI `workflow.json` is provided in this repo in `examples/comfy`.

*Disclaimer: the model has been trained to work with a cfg scale of 1 and an LCM scheduler, but these parameters can be tweaked a bit.*

## Combining Flash Diffusion with Existing LoRAs 🎨

Flash Diffusion models can also be combined with existing LoRAs to unlock few-step generation in a **training-free** manner. They can be integrated straight into Hugging Face pipelines. See the example below.

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

user_lora_id = "TheLastBen/Papercut_SDXL"
trigger_word = "papercut"

flash_lora_id = "jasperai/flash-sdxl"

# Load the base pipeline
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16"
)

# Set the scheduler
pipe.scheduler = LCMScheduler.from_config(
    pipe.scheduler.config
)

# Load and combine the LoRAs
pipe.load_lora_weights(flash_lora_id, adapter_name="flash")
pipe.load_lora_weights(user_lora_id, adapter_name="lora")
pipe.set_adapters(["flash", "lora"], adapter_weights=[1.0, 1.0])

pipe.to(device="cuda", dtype=torch.float16)

prompt = f"{trigger_word} a cute corgi"

image = pipe(
    prompt, num_inference_steps=4, guidance_scale=0
).images[0]
```

## License

This code is released under the [Creative Commons BY-NC 4.0 license](https://creativecommons.org/licenses/by-nc/4.0/legalcode.en).

## Citation

If you find this work useful or use it in your research, please consider citing us:

```bibtex
@misc{chadebec2024flash,
  title={Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation},
  author={Clement Chadebec and Onur Tasar and Eyal Benaroche and Benjamin Aubin},
  year={2024},
  eprint={2406.02347},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```