nbasyl / OFQ

The official implementation of the ICML 2023 paper OFQ-ViT
MIT License

OFQ: Oscillation-free Quantization for Low-bit Vision Transformers

This repository contains the training code for the ViT models introduced in our work, "Oscillation-free Quantization for Low-bit Vision Transformers", which has been accepted at ICML 2023. Please consider starring the repo if you find our work useful. Thanks!

In this work, we discuss weight oscillation in quantization-aware training, where quantized weights frequently jump between two quantization levels, and show how it degrades model performance. We find that the learnable scaling factor, a de facto standard in quantization, aggravates weight oscillation. We propose three techniques to address this: statistical weight quantization (StatsQ), confidence-guided annealing (CGA), and query-key reparameterization (QKR). Evaluated on ViT models, these techniques improve quantization robustness and accuracy: our 2-bit DeiT-T and DeiT-S outperform the previous state of the art on ImageNet by 9.8% and 7.7%, respectively.
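To make "statistical scaling factor" and "weight oscillation" concrete, here is a minimal pure-Python sketch. It derives the quantization scale from weight statistics instead of learning it, and counts how often each weight's integer level reverses direction across training steps. The scale rule `2 * mean(|w|) / Q_P`, the function names, and the direction-reversal definition of an oscillation are illustrative assumptions, not the repository's implementation or the exact StatsQ formula from the paper.

```python
# Minimal sketch (hypothetical, not the repository's code): quantize weights with a
# scale derived from weight statistics, and count oscillations across training steps.
import math
from typing import List


def stats_scale(weights: List[float], bits: int) -> float:
    """Scale computed from the weights themselves rather than learned.

    Assumption: scale = 2 * mean(|w|) / Q_P, an illustrative statistic;
    the exact StatsQ formula in the paper may differ. Requires bits >= 2.
    """
    if bits < 2:
        raise ValueError("this sketch assumes at least 2 bits")
    q_pos = 2 ** (bits - 1) - 1  # largest signed integer level, e.g. 1 for 2 bits
    mean_abs = sum(abs(w) for w in weights) / len(weights)
    return 2.0 * mean_abs / q_pos


def quantize(weights: List[float], scale: float, bits: int) -> List[int]:
    """Round w / scale to the nearest integer level (ties round up) and clamp."""
    q_neg, q_pos = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return [max(q_neg, min(q_pos, math.floor(w / scale + 0.5))) for w in weights]


def count_oscillations(history: List[List[int]]) -> List[int]:
    """Per weight, count how often its integer level reverses direction.

    history[t][i] is the integer level of weight i at step t. A reversal such as
    0 -> 1 -> 0 is counted as one oscillation.
    """
    counts = [0] * len(history[0])
    last_dir = [0] * len(history[0])  # sign of the most recent non-zero level change
    for prev, curr in zip(history, history[1:]):
        for i, (a, b) in enumerate(zip(prev, curr)):
            step = (b > a) - (b < a)  # +1, -1, or 0
            if step == 0:
                continue
            if last_dir[i] == -step:  # direction flipped: one oscillation
                counts[i] += 1
            last_dir[i] = step
    return counts
```

In a real training loop, `history` would hold the integer levels recorded after each optimizer step; a weight stuck near a rounding threshold shows up as a high count.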

Fig. 1 - Trajectory of the statistical scaling factors (StatsQ) from the 10th transformer block of a 2-bit DeiT-S throughout CGA with four different boundary ranges [BR_0.003, BR_0.005, BR_0.007, BR_0.01]. The y-axis represents the value of the scaling factors.
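Fig. 1 varies the boundary range (BR) used by CGA. One plausible reading, sketched here purely as an assumption, is that a weight whose normalized value w / scale lies within BR of a rounding threshold is low-confidence (close to flipping levels), while weights farther away are high-confidence and can be frozen, in line with the abstract's statement that CGA "freezes the weights with high confidence". The normalization, the mask rule, and all function names are hypothetical; the paper defines the actual CGA procedure.

```python
# Minimal sketch (hypothetical, not the repository's code) of a boundary-distance
# confidence test in the spirit of CGA.
import math
from typing import List


def distance_to_threshold(w: float, scale: float) -> float:
    """Distance of w / scale to the nearest rounding threshold k + 0.5.

    Assumption: confidence is measured in this normalized (integer-level) unit;
    clamping at the ends of the quantization range is ignored for brevity.
    """
    x = w / scale
    return abs(x - (math.floor(x) + 0.5))


def high_confidence_mask(weights: List[float], scale: float, br: float) -> List[bool]:
    """True for weights farther than `br` from a threshold (treated as frozen)."""
    return [distance_to_threshold(w, scale) > br for w in weights]


def masked_sgd_step(
    weights: List[float], grads: List[float], lr: float, frozen: List[bool]
) -> List[float]:
    """Plain SGD update that leaves frozen (high-confidence) weights unchanged."""
    return [w if f else w - lr * g for w, g, f in zip(weights, grads, frozen)]
```

A larger `br` marks more weights as low-confidence, so fewer are frozen; the four BR values in Fig. 1 correspond to different settings of this kind of threshold.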

Run

1. Requirements:

Replace "/your/miniconda3/envs/ofq/lib/python3.8/site-packages/timm/data/dataset_factory.py" with "timm_fix_imagenet_loading_bugs/dataset_factory.py". With the original timm file, loading ImageNet fails with "TypeError: __init__() got an unexpected keyword argument 'download'".
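Rather than hard-coding the conda environment path, the replacement can be scripted by locating the installed timm package. The sketch below assumes that the patched file sits at `timm_fix_imagenet_loading_bugs/dataset_factory.py` relative to the repository root and that timm's copy lives at `<timm package dir>/data/dataset_factory.py`, as in the path above; the helper name and the `.bak` backup are our own additions, not part of the repository.

```python
# Hypothetical helper: overwrite timm's dataset_factory.py with the patched version,
# keeping a one-time backup of the original so the change can be reverted.
import os
import shutil


def replace_dataset_factory(patched_file: str, timm_dir: str) -> str:
    """Copy `patched_file` over <timm_dir>/data/dataset_factory.py; return the target path."""
    target = os.path.join(timm_dir, "data", "dataset_factory.py")
    if not os.path.isfile(target):
        raise FileNotFoundError(f"timm dataset_factory.py not found at {target}")
    backup = target + ".bak"
    if not os.path.exists(backup):
        shutil.copy2(target, backup)  # back up the original only once
    shutil.copy2(patched_file, target)
    return target
```

From the repository root, inside the `ofq` environment, it would be called as `replace_dataset_factory("timm_fix_imagenet_loading_bugs/dataset_factory.py", os.path.dirname(timm.__file__))` after `import timm`.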

2. Data:

3. Pretrained models:

4. Steps to run:

Models

1. ImageNet1K dataset

Model         #Bits (W-A)   Top-1 Acc. (%)   Eval script
DeiT-T        32-32         72.02            -
OFQ DeiT-T    2-2           64.33            eval_scripts/deit_t/w2a2.sh
OFQ DeiT-T    3-3           72.72            eval_scripts/deit_t/w3a3.sh
OFQ DeiT-T    4-4           75.46            eval_scripts/deit_t/w4a4.sh
DeiT-S        32-32         79.9             -
OFQ DeiT-S    2-2           75.72            eval_scripts/deit_s/w2a2.sh
OFQ DeiT-S    3-3           79.57            eval_scripts/deit_s/w3a3.sh
OFQ DeiT-S    4-4           81.10            eval_scripts/deit_s/w4a4.sh
Swin-T        32-32         81.2             -
OFQ Swin-T    2-2           78.52            eval_scripts/swin_t/w2a2.sh
OFQ Swin-T    3-3           81.09            eval_scripts/swin_t/w3a3.sh
OFQ Swin-T    4-4           81.88            eval_scripts/swin_t/w4a4.sh
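The gap between each OFQ model and its full-precision (32-32) baseline follows directly from the table. The snippet below copies the Top-1 numbers from the table and reports the differences in percentage points; the dictionary layout and function name are only for illustration.

```python
# Top-1 accuracies (%) copied from the ImageNet-1K table above.
FULL_PRECISION = {"DeiT-T": 72.02, "DeiT-S": 79.9, "Swin-T": 81.2}
OFQ = {
    "DeiT-T": {"2-2": 64.33, "3-3": 72.72, "4-4": 75.46},
    "DeiT-S": {"2-2": 75.72, "3-3": 79.57, "4-4": 81.10},
    "Swin-T": {"2-2": 78.52, "3-3": 81.09, "4-4": 81.88},
}


def accuracy_gaps() -> dict:
    """Map (model, W-A bits) to OFQ accuracy minus the 32-32 baseline, in points."""
    return {
        (model, bits): round(acc - FULL_PRECISION[model], 2)
        for model, results in OFQ.items()
        for bits, acc in results.items()
    }


for (model, bits), gap in accuracy_gaps().items():
    print(f"{model} {bits}: {gap:+.2f}")
```

By this arithmetic, the 2-bit models trail their baselines by 7.69 (DeiT-T), 4.18 (DeiT-S), and 2.68 (Swin-T) points, while all three 4-bit models score above their 32-32 baselines.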

Acknowledgement

The original code is borrowed from DeiT.

Citation

If you find our code useful for your research, please consider citing:

@InProceedings{pmlr-v202-liu23w,
  title =    {Oscillation-free Quantization for Low-bit Vision Transformers},
  author =       {Liu, Shih-Yang and Liu, Zechun and Cheng, Kwang-Ting},
  booktitle =    {Proceedings of the 40th International Conference on Machine Learning},
  pages =    {21813--21824},
  year =     {2023},
  editor =   {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},
  volume =   {202},
  series =   {Proceedings of Machine Learning Research},
  month =    {23--29 Jul},
  publisher =    {PMLR},
  pdf =      {https://proceedings.mlr.press/v202/liu23w/liu23w.pdf},
  url =      {https://proceedings.mlr.press/v202/liu23w.html},
  abstract =     {Weight oscillation is a by-product of quantization-aware training, in which quantized weights frequently jump between two quantized levels, resulting in training instability and a sub-optimal final model. We discover that the learnable scaling factor, a widely-used $\textit{de facto}$ setting in quantization aggravates weight oscillation. In this work, we investigate the connection between learnable scaling factor and quantized weight oscillation using ViT, and we additionally find that the interdependence between quantized weights in $\textit{query}$ and $\textit{key}$ of a self-attention layer also makes ViT vulnerable to oscillation. We propose three techniques correspondingly: statistical weight quantization ($\rm StatsQ$) to improve quantization robustness compared to the prevalent learnable-scale-based method; confidence-guided annealing ($\rm CGA$) that freezes the weights with $\textit{high confidence}$ and calms the oscillating weights; and $\textit{query}$-$\textit{key}$ reparameterization ($\rm QKR$) to resolve the query-key intertwined oscillation and mitigate the resulting gradient misestimation. Extensive experiments demonstrate that our algorithms successfully abate weight oscillation and consistently achieve substantial accuracy improvement on ImageNet. Specifically, our 2-bit DeiT-T/DeiT-S surpass the previous state-of-the-art by 9.8% and 7.7%, respectively. The code is included in the supplementary material and will be released.}
}

Contact

Shih-Yang Liu, HKUST (sliuau at connect.ust.hk)