
T-GATE: Temporally Gating Attention to Accelerate Diffusion Model for Free!
MIT License


TGATE-V1: Cross-Attention Makes Inference Cumbersome in Text-to-Image Diffusion Models
Wentian Zhang*  Haozhe Liu1*  Jinheng Xie2*  Francesco Faccio1,3  Mike Zheng Shou2  Jürgen Schmidhuber1,3 

1 AI Initiative, King Abdullah University of Science And Technology  

2 Show Lab, National University of Singapore   3 The Swiss AI Lab, IDSIA

TGATE-V2: Faster Diffusion Through Temporal Attention Decomposition
Haozhe Liu1,4*  Wentian Zhang*  Jinheng Xie2*  Francesco Faccio1,3  Mengmeng Xu4  Tao Xiang4  Mike Zheng Shou2  Juan-Manuel Pérez-Rúa4  Jürgen Schmidhuber1,3 

1 AI Initiative, King Abdullah University of Science And Technology  

2 Show Lab, National University of Singapore   3 The Swiss AI Lab, IDSIA   4 Meta

Code and Technical Report will be released soon!

Quick Introduction

We explore the role of the attention mechanism during inference in text-conditional diffusion models. Empirical observations suggest that cross-attention outputs converge to a fixed point after several inference steps. The convergence time naturally divides the entire inference process into two phases: an initial phase for planning text-oriented visual semantics, which are then translated into images in a subsequent fidelity-improving phase. Cross-attention is essential in the initial phase but almost irrelevant thereafter. Self-attention, however, initially plays a minor role but becomes increasingly important in the second phase. These findings yield a simple and training-free method called TGATE, which efficiently generates images by caching and reusing attention outputs at scheduled time steps. Experiments show TGATE's broad applicability to various existing text-conditional diffusion models, which it speeds up by 10-50%.

Images generated by the diffusion model with and without TGATE. Our method accelerates diffusion models without degrading generation quality. It is training-free and broadly complementary to existing acceleration techniques.
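The convergence observation behind TGATE can be probed directly: record the cross-attention output at each denoising step and measure how much it changes between consecutive steps. The sketch below uses synthetic tensors that decay toward a fixed point as a stand-in for real attention outputs (the helper name and shapes are illustrative, not part of the codebase):

```python
import torch

def convergence_curve(attn_outputs):
    """L2 distance between cross-attention outputs at consecutive
    denoising steps. TGATE's observation is that this curve flattens
    toward zero after the first few steps."""
    return [torch.norm(b - a).item()
            for a, b in zip(attn_outputs, attn_outputs[1:])]

torch.manual_seed(0)
# Synthetic stand-in: outputs that decay toward a fixed point,
# mimicking the convergence reported for real cross-attention maps.
fixed_point = torch.ones(4, 64)
outputs = [fixed_point + (0.5 ** t) * torch.randn(4, 64) for t in range(10)]
curve = convergence_curve(outputs)
print(curve)  # step-to-step distances shrink rapidly
```

In a real pipeline one would collect `attn_outputs` with forward hooks on the cross-attention modules; the point where the curve flattens is a natural choice for the gate step.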

🚀 Major Features

📄 Updates

📖 Key Observation

Images generated by the diffusion model at different denoising steps. The first row feeds the text embedding to the cross-attention modules at every step. The second row uses the text embedding only from step 1 to step 10, and the third row only from step 11 to step 25.

We summarize our observations as follows:

🖊️ Method

```python
if cross_attn and (gate_step < cur_step):
    # Past the gate step: skip the cross-attention computation
    # and reuse the output cached at the gate step.
    hidden_states = cache
```
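Expanded into a runnable sketch, the gating logic looks roughly as follows. This is illustrative only: the class and argument names are invented for this example and are not the repository's implementation.

```python
import torch

class GatedCrossAttention(torch.nn.Module):
    """Toy cross-attention layer with TGATE-style temporal gating:
    once cur_step passes gate_step, the cached output is reused
    instead of recomputing attention against the text embedding.
    (Illustrative sketch, not the repository's actual module.)"""

    def __init__(self, dim, gate_step=10):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.gate_step = gate_step
        self.cache = None

    def forward(self, hidden_states, text_embed, cur_step):
        if self.cache is not None and self.gate_step < cur_step:
            return self.cache  # converged: no attention compute needed
        out, _ = self.attn(hidden_states, text_embed, text_embed)
        if cur_step == self.gate_step:
            self.cache = out.detach()  # cache the output at the gate step
        return out

torch.manual_seed(0)
gate = GatedCrossAttention(dim=32, gate_step=2)
hidden = torch.randn(1, 8, 32)  # image tokens
text = torch.randn(1, 4, 32)    # text embedding
outs = [gate(hidden, text, cur_step=s) for s in range(1, 6)]
# steps 3-5 return the step-2 cache without recomputing attention
```

Because the cached tensor is reused verbatim, every step after the gate saves the full cost of the cross-attention block, which is where the MACs and latency reductions in the results below come from.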

📄 Results

| Model | MACs | Params | Latency | Zero-shot 10K-FID on MS-COCO |
|---|---|---|---|---|
| SD-1.5 | 16.938T | 859.520M | 7.032s | 23.927 |
| SD-1.5 w/ TGATE | 9.875T | 815.557M | 4.313s | 20.789 |
| SD-2.1 | 38.041T | 865.785M | 16.121s | 22.609 |
| SD-2.1 w/ TGATE | 22.208T | 815.433M | 9.878s | 19.940 |
| SD-XL | 149.438T | 2.570B | 53.187s | 24.628 |
| SD-XL w/ TGATE | 84.438T | 2.024B | 27.932s | 22.738 |
| Pixart-Alpha | 107.031T | 611.350M | 61.502s | 38.669 |
| Pixart-Alpha w/ TGATE | 65.318T | 462.585M | 37.867s | 35.825 |
| DeepCache (SD-XL) | 57.888T | - | 19.931s | 23.755 |
| DeepCache w/ TGATE | 43.868T | - | 14.666s | 23.999 |
| LCM (SD-XL) | 11.955T | 2.570B | 3.805s | 25.044 |
| LCM w/ TGATE | 11.171T | 2.024B | 3.533s | 25.028 |
| LCM (Pixart-Alpha) | 8.563T | 611.350M | 4.733s | 36.086 |
| LCM w/ TGATE | 7.623T | 462.585M | 4.543s | 37.048 |

The latency is measured on an NVIDIA GTX 1080 Ti.

The MACs and Params are calculated by calflops.

The FID is calculated by PytorchFID.
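As a sanity check on the 10-50% speedup claim, the latency column above implies the following reductions (plain arithmetic on the reported numbers):

```python
# (baseline, with TGATE) latency pairs in seconds, read from the table above.
latency = {
    "SD-1.5": (7.032, 4.313),
    "SD-2.1": (16.121, 9.878),
    "SD-XL": (53.187, 27.932),
    "Pixart-Alpha": (61.502, 37.867),
    "LCM (SD-XL)": (3.805, 3.533),
}

savings = {m: round(100 * (1 - t / b), 1) for m, (b, t) in latency.items()}
for model, pct in savings.items():
    print(f"{model}: {pct}% latency reduction")
```

SD-XL sees the largest reduction (about 47%), while LCM, which already runs in very few steps, gains the least (about 7%).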

🛠️ Requirements

🌟 Usage

Examples

To use TGATE to accelerate the denoising process, simply run main.py. For example:

```bash
python main.py \
--prompt 'A coral reef bustling with diverse marine life.' \
--model 'sd_2.1' \
--gate_step 10 \
--saved_path './sd_2_1.png' \
--inference_step 25
```

```bash
python main.py \
--prompt 'Astronaut in a jungle, cold color palette, muted colors, detailed, 8k' \
--model 'sd_xl' \
--gate_step 10 \
--saved_path './sd_xl.png' \
--inference_step 25
```

```bash
python main.py \
--prompt 'An alpaca made of colorful building blocks, cyberpunk.' \
--model 'pixart' \
--gate_step 8 \
--saved_path './pixart_alpha.png' \
--inference_step 25
```

```bash
python main.py \
--prompt 'Self-portrait oil painting, a beautiful cyborg with golden hair, 8k' \
--model 'lcm_sdxl' \
--gate_step 1 \
--saved_path './lcm_sdxl.png' \
--inference_step 4
```

```bash
python main.py \
--prompt 'A haunted Victorian mansion under a full moon.' \
--model 'sd_xl' \
--gate_step 10 \
--saved_path './sd_xl_deepcache.png' \
--inference_step 25 \
--deepcache
```
  1. For LCMs, set gate_step to 1 or 2 and inference_step to 4.

  2. To use DeepCache, pass the --deepcache flag.

Third-party Usage

📖 Related works:

We encourage users to read DeepCache and Adaptive Guidance.

| Method | U-Net | Transformer | Consistency Model |
|---|---|---|---|
| DeepCache | ✓ | - | - |
| Adaptive Guidance | ✓ | ✓ | ✓ |
| TGATE (Ours) | ✓ | ✓ | ✓ |

Compared with DeepCache:

Compared with Adaptive Guidance:

Acknowledgment

Citation

If you find our work inspiring or use our codebase in your research, please consider giving a star ⭐ and a citation.

@article{tgate,
  title={Cross-Attention Makes Inference Cumbersome in Text-to-Image Diffusion Models},
  author={Zhang, Wentian and Liu, Haozhe and Xie, Jinheng and Faccio, Francesco and Shou, Mike Zheng and Schmidhuber, J{\"u}rgen},
  journal={arXiv preprint arXiv:2404.02747},
  year={2024}
}