
# ViDiT-Q: Efficient and Accurate Quantization of Diffusion Transformers for Image and Video Generation


[arXiv](https://arxiv.org/abs/2406.02540) | [Project Page](https://a-suozhang.xyz/viditq.github.io/)

This repo contains the official code of ViDiT-Q: Efficient and Accurate Quantization of Diffusion Transformers for Image and Video Generation.

We introduce ViDiT-Q, a quantization method specialized for diffusion transformers. For popular large-scale models (e.g., OpenSORA, Latte, PixArt-α, PixArt-Σ) for video and image generation tasks, ViDiT-Q achieves W8A8 quantization without metric degradation, and W4A8 without notable visual quality degradation.

ViDiT-Q generates videos/images with negligible discrepancy from the FP16 baseline:

*(Comparison, left to right: FP16 baseline, naive quant W8A8, ViDiT-Q W8A8.)*

For more information, please refer to our Project Page: https://a-suozhang.xyz/viditq.github.io/

# Env Setup

We recommend using conda for environment management.

```bash
cd diffuser-dev

# create a virtual env
conda create -n viditq python=3.10
# activate virtual environment
conda activate viditq

# xformers (required by opensora) needs torch 2.1.1; the newest torch is not compatible
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=12.1 -c pytorch -c nvidia  # pip install also works

pip install -r t2i/requirements_pixart.txt
pip install -r t2v/requirements_opensora.txt
pip install -r t2v/requirements_qdiff.txt

# install flash attention (optional)
pip install packaging ninja
pip install flash-attn --no-build-isolation

# install xformers
pip3 install xformers==0.0.23

# install the viditq package (containing our qdiff)
pip install -e .

# install opensora
cd t2v
pip install -e .
```


# Commands to Run

After running the commands below, the outputs (quantized checkpoints, generated videos) will appear in `./logs/`.

We provide shell scripts for all processes below in `t2i/shell_scripts` and `t2v/shell_scripts`. For example, run `bash t2v/shell_scripts/get_calib_data.sh $GPU_ID` to generate the calibration dataset.

## 🎬 video generation

### 0.0 Download and convert the checkpoint of the STDiT (OpenSORA) model

Please refer to the [OpenSORA v1.0](https://github.com/hpcaitech/Open-Sora) documentation for more details. We only support OpenSORA v1.0 for now; newer versions will be supported later.

The original OpenSORA code merges the q/k/v linears into one linear layer with more output channels; we split it into 3 layers for quantization:

```bash
python t2v/scripts/split_ckpt.py
```
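For intuition, the split amounts to slicing the merged weight into three equal chunks so that q/k/v can each carry their own quantization parameters. A minimal sketch (shapes and q/k/v ordering are illustrative assumptions; `t2v/scripts/split_ckpt.py` does the actual checkpoint surgery):

```python
import torch
import torch.nn as nn

dim = 1152  # illustrative hidden size
merged = nn.Linear(dim, 3 * dim)  # stand-in for the fused qkv projection

# slice the fused weight/bias into three dim->dim linears (q, k, v order assumed)
q, k, v = (nn.Linear(dim, dim) for _ in range(3))
for i, layer in enumerate((q, k, v)):
    layer.weight.data = merged.weight.data[i * dim:(i + 1) * dim].clone()
    layer.bias.data = merged.bias.data[i * dim:(i + 1) * dim].clone()

# the split layers reproduce the fused output exactly
x = torch.randn(2, 16, dim)
assert torch.allclose(torch.cat([q(x), k(x), v(x)], dim=-1), merged(x), atol=1e-5)
```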

### 0.1 FP16 inference

We provide a precomputed `text_embeds.pth` for 10 OpenSORA example prompts in `t2v/utils_files`, which avoids loading the T5 checkpoint onto the GPU (which takes around 1 min and around 10 GB of memory). Please add `--precompute_text_embeds ./t2v/utils_files/text_embeds.pth` when running commands.

CFG="./t2v/configs/opensora/inference/16x512x512.py"  # the opensora config
CKPT_PATH="./logs/split_ckpt/OpenSora-v1-HQ-16x512x512-split.pth"  # your path of splited ckpt
OUTDIR="./logs/fp16_inference"  # your_path_to_save_videos
GPU_ID=$1

CUDA_VISIBLE_DEVICES=$GPU_ID python t2v/scripts/inference.py $CFG --ckpt_path $CKPT_PATH  --outdir $OUTDIR \
--precompute_text_embeds ./t2v/utils_files/text_embeds.pth

### 1.1 Generate calib data

CFG="./t2v/configs/opensora/inference/16x512x512.py" # the opensora config
CKPT_PATH="./logs/split_ckpt/OpenSora-v1-HQ-16x512x512-split.pth"  # splited ckpt
GPU_ID=$1
CALIB_DATA_DIR="./logs/calib_data"  # the path to save your calib dataset

# quant calib data
CUDA_VISIBLE_DEVICES=$GPU_ID python t2v/scripts/get_calib_data.py $CFG --ckpt_path $CKPT_PATH --data_num 10 --outdir $CALIB_DATA_DIR --save_dir $CALIB_DATA_DIR \
--precompute_text_embeds ./t2v/utils_files/text_embeds.pth

### 1.2 Post Training Quantization (PTQ) Process

We provide configs for different quantization techniques; for details of these configs, please refer to our paper. Note that some techniques are compatible with, but not applied at, higher bitwidths (e.g., W8A8) for simplicity.

| Plan | CFG Name |
| --- | --- |
| ViDiT-Q W8A8 | `w8a8_dynamic.yaml` |
| ViDiT-Q W6A6 | `w6a6_naive_cb.yaml` |
| ViDiT-Q W4A8 | `w4a8_timestep_aware_cb.yaml` |
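For a rough sense of what the `dynamic` in `w8a8_dynamic.yaml` refers to: weights get static per-output-channel scales computed offline, while activation scales are recomputed per token at runtime. A simplified fake-quant sketch (our actual implementation lives in the `qdiff` package and differs in detail):

```python
import torch

def w8a8_dynamic_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # static, per-output-channel symmetric weight quantization (done offline)
    w_scale = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127
    w_q = (w / w_scale).round().clamp(-128, 127) * w_scale

    # dynamic, per-token symmetric activation quantization (done at runtime,
    # so it tracks the large activation variation across tokens and timesteps)
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 127
    x_q = (x / x_scale).round().clamp(-128, 127) * x_scale

    return x_q @ w_q.t()  # simulated (fake-quant) INT8 matmul

y = w8a8_dynamic_linear(torch.randn(2, 16, 1152), torch.randn(1152, 1152))
```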
EXP_NAME="w8a8_naive"

CFG="./t2v/configs/quant/opensora/16x512x512.py"  # the opensora config
Q_CFG="./t2v/configs/quant/opensora/$EXP_NAME.yaml"  # TODO: the config of PTQ
CKPT_PATH="./logs/split_ckpt/OpenSora-v1-HQ-16x512x512-split.pth"  # splited ckpt generated by split_ckpt.py
CALIB_DATA_DIR="./logs/calib_data"  # your path of calib data
OUTDIR="./logs/$EXP_NAME"  # TODO: your path to save the ptq result
GPU_ID=$1

# ptq
CUDA_VISIBLE_DEVICES=$GPU_ID python t2v/scripts/ptq.py $CFG --ckpt_path $CKPT_PATH --ptq_config $Q_CFG --outdir $OUTDIR \
    --calib_data $CALIB_DATA_DIR/calib_data.pt \
    --part_fp \
    --precompute_text_embeds ./t2v/utils_files/text_embeds.pth

### 1.3 Quantized Model Inference

#### 1.3.1 Normal quantized inference

EXP_NAME="w8a8_naive"

CFG="./t2v/configs/quant/opensora/16x512x512.py" # the opensora config
CKPT_PATH="./logs/split_ckpt/OpenSora-v1-HQ-16x512x512-split.pth"  # your path of splited ckpt
OUTDIR="./logs/$EXP_NAME"  # your path of the w8a8 ptq result
GPU_ID=$1
# SAVE_DIR="W8A8_ptq"  # your path to save generated, leave blank to save at $OUTDIR/generated_videos

# quant inference
CUDA_VISIBLE_DEVICES=$GPU_ID python t2v/scripts/quant_txt2video.py $CFG \
    --outdir $OUTDIR --ckpt_path $CKPT_PATH  \
    --dataset_type opensora \
    --part_fp \
    --precompute_text_embeds ./t2v/utils_files/text_embeds.pth \
    # --save_dir $SAVE_DIR \

#### 1.3.2 Mixed precision quantized inference

```bash
EXP_NAME='w4a8_timestep_cb'

CFG="./t2v/configs/quant/opensora/16x512x512.py"  # the opensora config
CKPT_PATH="./logs/split_ckpt/OpenSora-v1-HQ-16x512x512-split.pth"  # split ckpt generated by split_ckpt.py
OUTDIR="./logs/$EXP_NAME"  # the path of the W4A8 PTQ result
GPU_ID=$1
MP_W_CONFIG="./t2v/configs/quant/W4A8_Naive_Smooth/t20_weight_4_mp.yaml"  # the mixed precision config for weights
MP_A_CONFIG="./t2v/configs/quant/W4A8_Naive_Smooth/t20_act_8_mp.yaml"  # the mixed precision config for activations
# SAVE_DIR="W4A8_Naive_Smooth_samples"  # leave blank to use the default path $OUTDIR/generated_videos

# quant inference
CUDA_VISIBLE_DEVICES=$GPU_ID python t2v/scripts/quant_txt2video_mp.py $CFG --outdir $OUTDIR --ckpt_path $CKPT_PATH --dataset_type opensora \
    --part_fp \
    --timestep_wise_mp \
    --time_mp_config_weight $MP_W_CONFIG \
    --time_mp_config_act $MP_A_CONFIG \
    --precompute_text_embeds ./t2v/utils_files/text_embeds.pth
    # append --save_dir $SAVE_DIR to override the output location
```
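Conceptually, timestep-wise mixed precision assigns each denoising step its own bitwidth (the `t20_*` configs cover a 20-step schedule) and re-quantizes accordingly as sampling proceeds. A simplified sketch with a made-up schedule; the real schedules come from the YAML configs above:

```python
import torch

# hypothetical schedule: from a given step onward, use the mapped bitwidth
weight_bits = {0: 8, 8: 6, 14: 4}  # step -> weight bitwidth (made-up values)

def bits_at(step: int, schedule: dict) -> int:
    # pick the entry for the latest boundary not exceeding `step`
    return schedule[max(s for s in schedule if s <= step)]

def fake_quant(w: torch.Tensor, n_bits: int) -> torch.Tensor:
    qmax = 2 ** (n_bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / qmax
    return (w / scale).round().clamp(-qmax - 1, qmax) * scale

w = torch.randn(1152, 1152)
for step in range(20):  # 20 denoising steps, matching the "t20" configs
    w_q = fake_quant(w, bits_at(step, weight_bits))
    # ... run the DiT forward pass with w_q for this timestep ...
```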

#### 1.3.3 Get Sensitivity (optional)

This computes quantization sensitivity with timestep-wise + block-wise + group-wise quantization:

```bash
python t2v/scripts/get_sensitivity.py $CFG --ckpt_path $CKPT_PATH --outdir $OUTDIR --save_dir $SAVE_PATH \
    --dataset_type opensora --precompute_text_embeds $PRE_COMPUTED_TEXTEMB --part_fp \
    --block_group_wise_quant --timestep_wise_quant --group_quant quant_group
```
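As an illustration of what a sensitivity score captures: quantize one (timestep, block) slot at a time and measure how far the output drifts from the FP16 output; high-error slots stay at higher bitwidth in the mixed-precision plan. A toy metric sketch (the paper defines the criterion actually used, and `run_model`/`quantize_only` below are hypothetical stand-ins for the repo's logic):

```python
import torch

def sensitivity(fp_out: torch.Tensor, q_out: torch.Tensor) -> float:
    # toy proxy: relative MSE of the quantized output vs the FP16 output
    return (torch.mean((fp_out - q_out) ** 2) / torch.mean(fp_out ** 2)).item()

# sketch of the sweep over slots (hypothetical helpers):
# scores = {(t, b): sensitivity(run_model(t=t), run_model(t=t, quantize_only=b))
#           for t in range(20) for b in range(num_blocks)}
```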


<br>

## 🖼️ image generation

### 0.0 Downloading model weights
Download the corresponding model weights at the following links. For PixArt-alpha, please place the folders for the tokenizer and VAE weights under the same directory.

Model weights: \[[PixArt-alpha](https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth), [PixArt-sigma](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-1024-MS.pth)\] <br>
Tokenizer and VAE weights: \[PixArt-alpha: ([t5](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl), [vae](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/sd-vae-ft-ema)), [PixArt-sigma](https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers)\]

### 0.1 (optional) Precompute the text embedding

> The PixArt model family adopts `T5-XXL` as its text encoder, which costs ~10 GB of GPU memory. To save GPU memory and avoid the long model-loading time, we support precomputing the text embeddings (sketched after the list below).

- `bash ./t2i/shell_scripts/get_text_embeds.sh $GPU_ID`: reads the prompts from `./t2i/asset/${TXT_NAME}.txt` and saves the T5 text embeddings into a file named like `text_embeds_alpha_calib`. This file can then be passed via `--precompute_text_embeds` for further processes.
    - We provide 3 prompt lists:
        - `calib.txt`: the first 64 of the PixArt example prompts.
        - `samples.txt`: the complete 120 PixArt example prompts.
        - `coco_1024.txt`: the first 1024 prompts from the COCO annotations (used for evaluation).
    - The text embeddings for PixArt-alpha and PixArt-sigma are identical for prompts up to 120 tokens (PixArt-alpha's maximum token length is 120; PixArt-sigma's is 300).
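For intuition, the precompute step amounts to running the T5 encoder once over a prompt list and caching the hidden states. A minimal sketch with Hugging Face `transformers` (the T5 path, max length, and saved dict layout are assumptions; `get_text_embeds.sh` is authoritative):

```python
import torch
from transformers import T5EncoderModel, T5Tokenizer

t5_dir = "path/to/t5-v1_1-xxl"  # the T5 weights downloaded in step 0.0
tokenizer = T5Tokenizer.from_pretrained(t5_dir)
encoder = T5EncoderModel.from_pretrained(t5_dir, torch_dtype=torch.float16).cuda().eval()

with open("./t2i/asset/calib.txt") as f:
    prompts = [line.strip() for line in f if line.strip()]

tokens = tokenizer(prompts, padding="max_length", max_length=120,  # pixart-alpha limit
                   truncation=True, return_tensors="pt").to("cuda")
with torch.no_grad():
    embeds = encoder(**tokens).last_hidden_state

# cache once; later runs pass this file via --precompute_text_embeds
torch.save({"text_embeds": embeds.cpu(), "attention_mask": tokens.attention_mask.cpu()},
           "text_embeds_alpha_calib.pth")
```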

### 0.2 FP16 Inference

- `bash ./t2i/shell_scripts/fp16_inference.sh $GPU_ID`: FP16 inference for image generation.
    - configure `--version` to choose 'alpha' or 'sigma'
    - specify the path of the precomputed text embeddings with `--precompute_text_embeds`

### 1.1 Generate calibration data

- `bash ./t2i/shell_scripts/get_calib_data.sh $GPU_ID`: generate the calibration data.

### 1.2 Post Training Quantization (PTQ) Process

- `bash ./t2i/shell_scripts/ptq.sh $GPU_ID`: run the PTQ process on the calibration data and generate the quantized checkpoint.
    - the quantization configs are in the `t2i/configs/quant/$version` folder; `w8a8_naive.yaml` is the baseline quantization, and `w8a8.yaml` is the ViDiT-Q plan.

### 1.3 Quantized Inference

- `bash ./t2i/shell_scripts/quant_inference.sh $GPU_ID`: run quantized model inference.

<br>

# Citation

If you find our work helpful, please consider citing:

```bibtex
@misc{zhao2024viditq,
      title={ViDiT-Q: Efficient and Accurate Quantization of Diffusion Transformers for Image and Video Generation},
      author={Tianchen Zhao and Tongcheng Fang and Enshu Liu and Wan Rui and Widyadewi Soedarmadji and Shiyao Li and Zinan Lin and Guohao Dai and Shengen Yan and Huazhong Yang and Xuefei Ning and Yu Wang},
      year={2024},
      eprint={2406.02540},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```



# Acknowledgments
Our code was developed based on [opensora v1.0](https://github.com/hpcaitech/Open-Sora) (Apache License), [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha) (AGPL-3.0 license), [PixArt-sigma](https://github.com/PixArt-alpha/PixArt-sigma) (AGPL-3.0 license), and [q-diffusion](https://github.com/Xiuyu-Li/q-diffusion) (MIT License).