yanghr / BSQ

BSQ: Exploring Bit-Level Sparsity for Mixed-Precision Neural Network Quantization (ICLR 2021)
https://openreview.net/forum?id=TiXl51SCNw8
Apache License 2.0

CIFAR-10 experiments

This folder contains the code for inducing mixed-precision quantization schemes with BSQ on the CIFAR-10 dataset. The ResNet model code is configured to use bit representation, which supports BSQ training and reproduces the results reported in the main paper.
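As a rough illustration of what "bit representation" means, the sketch below quantizes a layer's float weights to `n_bits` and splits each signed integer into a positive and a negative part, each stored as separate bit planes (plane `b` holds bit `b` of every weight, least significant bit first). The symmetric max-abs scale and the function name `to_bit_planes` are assumptions for illustration; the repository's own scaling and data layout may differ.

```python
def to_bit_planes(weights, n_bits):
    """Quantize float weights to n_bits and split them into bit planes.

    Hypothetical helper: uses a symmetric max-abs scale (an assumption, not
    necessarily what the BSQ code does). Returns (scale, ints, pos, neg),
    where pos[b][i] / neg[b][i] is bit b of the positive / negative
    magnitude of weight i (LSB first).
    """
    levels = 2 ** n_bits - 1
    max_abs = max(abs(w) for w in weights) or 1.0  # avoid a zero scale
    scale = max_abs / levels
    ints = [round(w / scale) for w in weights]  # signed ints in [-levels, levels]
    pos = [[(max(q, 0) >> b) & 1 for q in ints] for b in range(n_bits)]
    neg = [[(max(-q, 0) >> b) & 1 for q in ints] for b in range(n_bits)]
    return scale, ints, pos, neg


scale, ints, pos, neg = to_bit_planes([0.9, -0.3, 0.0, 0.6], n_bits=2)
print(ints)  # [3, -1, 0, 2]
print(pos)   # [[1, 0, 0, 0], [1, 0, 0, 1]]
print(neg)   # [[0, 1, 0, 0], [0, 0, 0, 0]]
```

Keeping the sign in two non-negative parts means every stored bit is 0 or 1, which is what allows each bit plane to be treated (and later removed) as a unit.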

Acknowledgement

The training and evaluation code and the model architectures are adapted from bearpaw/pytorch-classification.

Install

Clone recursively:

git clone --recursive https://github.com/yanghr/BSQ.git

Specification of dependencies

This code is tested with Python 3.6.8, PyTorch 1.2.0 and TorchVision 0.4.0. We recommend using the provided spec-file.txt to replicate the Anaconda environment used for testing this code, which can be done with:

conda create --name myenv --file spec-file.txt

We suggest running this code on a GPU for the best efficiency. Both single-GPU training and parallel training on multiple GPUs are supported.

Usage

Pretrained models

As described in Appendix A.1, pretrained models are used to initialize BSQ training. The pretrained models are provided in the checkpoints/cifar10/ folder: the checkpoint in resnet-20/ is the full-precision pretrained model, and the checkpoint in resnet-20-8/ is the 8-bit quantized model in bit representation.

For more details on training the full-precision model, please see the training recipes provided by bearpaw/pytorch-classification. The quantized model is obtained with convert.py, which is introduced below.

BSQ training

Here we perform BSQ training on the ResNet-20 model on the CIFAR-10 dataset.

python cifar_prune_STE.py -a resnet --depth 20 --epochs 350 --lr 0.1 --schedule 250 --gamma 0.1 --wd 1e-4 --model checkpoints/cifar10/resnet-20-8/model_best.pth.tar --decay 0.01 --Prun_Int 100 --thre 0.0 --checkpoint checkpoints/cifar10/xxx --Nbits 8 --act 4 --bin --L1 >xxx.txt

Replace xxx in the command with the folder where you want to save the resulting model. The model will be saved in bit representation. We suggest redirecting the printed output to a text file with >xxx.txt, both to avoid interfering with the progress bar display and to keep a record of the training process.

--decay sets the regularization strength $$\alpha$$ in Equation (5), which controls the trade-off between accuracy and model size. Results for different values of $$\alpha$$ are shown in Section 4.2.
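To give a feel for what $$\alpha$$ scales, here is a minimal sketch of a bit-level group-Lasso style penalty for one layer: for each bit position, take the L2 norm of that whole bit plane (positive and negative parts together), sum over bit positions, and multiply by $$\alpha$$. This is our simplified reading of the regularizer, not the code in cifar_prune_STE.py; any per-layer weighting that Equation (5) applies is deliberately left out, and the function name is hypothetical.

```python
import math


def bit_group_lasso(pos_planes, neg_planes, alpha):
    """alpha * sum_b ||[pos_b; neg_b]||_2 for a single layer (simplified sketch).

    pos_planes[b] / neg_planes[b] hold the relaxed (continuous) values of
    bit b for every weight in the layer. Per-layer reweighting is omitted.
    """
    total = 0.0
    for pos_b, neg_b in zip(pos_planes, neg_planes):
        # One group per bit position: the norm covers the entire bit plane.
        total += math.sqrt(sum(v * v for v in pos_b) + sum(v * v for v in neg_b))
    return alpha * total


# Bit 0 has norm sqrt(0.6^2 + 0.8^2) = 1.0, bit 1 is already all zero.
print(bit_group_lasso([[0.6, 0.0], [0.0, 0.0]], [[0.0, 0.8], [0.0, 0.0]], alpha=0.01))
```

Because each group is a whole bit plane, the penalty tends to drive entire planes to zero rather than individual bits, so a larger $$\alpha$$ pushes toward lower per-layer precision at some cost in accuracy.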

--Prun_Int is the number of epochs between consecutive re-quantization and precision adjustment steps; we suggest setting it to 100. The effect of using other intervals is illustrated in Appendix B.1.
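The re-quantization step can be pictured as follows: the relaxed bit values of each weight are recombined into a single value, rounded back to an integer, and the layer's new precision is the number of bits needed for the largest resulting magnitude, so bit planes that have become all zero at the top are dropped. This is a simplified, hypothetical sketch (magnitudes only, no sign or scale handling) and not the logic in cifar_prune_STE.py.

```python
def requantize(planes):
    """Recombine relaxed bit planes (LSB first) and choose a new precision.

    planes[b][i] is the possibly non-binary value of bit b of weight i.
    Returns (integer magnitudes, new precision, new binary bit planes).
    """
    n_weights = len(planes[0])
    ints = [
        max(0, round(sum(planes[b][i] * 2 ** b for b in range(len(planes)))))
        for i in range(n_weights)
    ]
    # Precision = bits needed for the largest magnitude in the layer.
    new_bits = max(q.bit_length() for q in ints)
    new_planes = [[(q >> b) & 1 for q in ints] for b in range(new_bits)]
    return ints, new_bits, new_planes


# A 3-bit layer whose most significant plane has shrunk toward zero.
ints, bits, new_planes = requantize([[0.9, 0.1, 1.2], [0.2, 0.95, 0.0], [0.0, 0.0, 0.1]])
print(ints, bits)  # [1, 2, 2] 2
```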

--act sets the quantization precision of the activations in the model. It should be kept the same in BSQ training and finetuning. The default value is 4.
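For intuition on what a 4-bit activation precision means, below is a sketch of a common clipped uniform activation quantizer: clip to [0, 1], then round to 2^k - 1 evenly spaced levels. The [0, 1] clipping range is an assumption for illustration; the repository may use a different range. During training, the non-differentiable rounding is typically bypassed in the backward pass with a straight-through estimator (STE).

```python
def quantize_activation(x, k=4):
    """Clip x to [0, 1] and round it to one of 2**k evenly spaced values."""
    levels = 2 ** k - 1
    clipped = min(max(x, 0.0), 1.0)
    return round(clipped * levels) / levels


print([quantize_activation(v) for v in (-0.2, 0.4, 1.7)])  # [0.0, 0.4, 1.0]
```

With k = 4 there are exactly 16 representable activation values.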

Evaluating and finetuning the obtained model

The model obtained from BSQ training can be evaluated and finetuned with cifar_finetune.py.

For evaluation, run

python cifar_finetune.py -a resnet --depth 20 --model checkpoints/cifar10/xxx/checkpoint.pth.tar --Nbits 8 --act 4 --bin --evaluate

Replace xxx in the command with the folder used to save the BSQ-trained model. Note that only models in bit representation can be evaluated this way. The output reports the test accuracy, the percentage of 1s in each bit of each layer's weights, and the precision assigned to each layer.
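The per-bit statistic reported during evaluation can be understood with a sketch like the one below, which computes, for a single layer, the percentage of weights whose bit `b` is 1 in either the positive or the negative part. Whether cifar_finetune.py counts the two sign parts together like this or separately is an assumption, and the function name is hypothetical.

```python
def ones_percentage(pos_planes, neg_planes):
    """Percentage of weights whose bit b is set, for each bit b (LSB first)."""
    result = []
    for pos_b, neg_b in zip(pos_planes, neg_planes):
        ones = sum(1 for p, n in zip(pos_b, neg_b) if p or n)
        result.append(100.0 * ones / len(pos_b))
    return result


# Four weights, two bit planes; the upper plane is entirely zero.
print(ones_percentage([[1, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 1, 0], [0, 0, 0, 0]]))
# [75.0, 0.0]
```

A bit whose percentage is 0 carries no information for that layer, which is why such statistics line up with the precision reported per layer.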

To further finetune the obtained model, use

python cifar_finetune.py -a resnet --depth 20 --epochs 300 --lr 0.01 --schedule 150 250 --gamma 0.1 --wd 1e-4 --model checkpoints/cifar10/xxx/checkpoint.pth.tar --checkpoint checkpoints/cifar10/xxx-ft --Nbits 8 --act 4 --bin >xxx-ft.txt

The quantization scheme is fixed throughout the finetuning process. At the end of finetuning, the model with the highest test accuracy is stored both in bit representation and as floating-point weights. The bit representation is saved in checkpoints/cifar10/xxx-ft/best_bin.pth.tar and the floating-point model is saved in checkpoints/cifar10/xxx-ft/best_float.pth.tar.
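Conceptually, turning a bit-representation layer into floating-point weights amounts to recombining the bit planes and multiplying by the layer's scale. The sketch below assumes the layout weight_i = scale * sum_b 2^b * (pos[b][i] - neg[b][i]) with a single per-layer scale; the function name and data layout are hypothetical and not taken from cifar_finetune.py.

```python
def from_bit_planes(scale, pos_planes, neg_planes):
    """Rebuild float weights from positive/negative bit planes (LSB first)."""
    n_weights = len(pos_planes[0])
    return [
        scale * sum((pos_planes[b][i] - neg_planes[b][i]) * 2 ** b
                    for b in range(len(pos_planes)))
        for i in range(n_weights)
    ]


# 2-bit layer with scale 0.3 and signed integer values [3, -1, 0, 2].
weights = from_bit_planes(0.3, [[1, 0, 0, 0], [1, 0, 0, 1]], [[0, 1, 0, 0], [0, 0, 0, 0]])
print([round(w, 6) for w in weights])  # [0.9, -0.3, 0.0, 0.6]
```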

Converting full-precision models to bit representation with obtained quantization schemes

To convert a full-precision model, use

python convert.py -a resnet --depth 20 --model checkpoints/cifar10/resnet-20/model_best.pth.tar --dict checkpoints/cifar10/xxx/checkpoint.pth.tar --checkpoint checkpoints/cifar10/xxx-mp --Nbits 8 --act 4 >xxx-mp.txt

If a path is provided in --dict, the model is converted to the same quantization scheme as the model specified in --dict. Otherwise, the whole model is quantized to the precision specified in --Nbits. The converted model is in bit representation and is saved in the folder specified in --checkpoint. We use this code to obtain the 8-bit quantized model used before BSQ training, and to obtain the "train from scratch" models that are finetuned and compared in Table 1.
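The choice between --dict and --Nbits can be sketched as: if a reference scheme is given, look up each layer's precision in it; otherwise use the uniform precision for every layer, then quantize each layer at its precision. Here a model is a plain dict of layer name to weight list, and a scheme is a dict of layer name to bit width. These data structures, the max-abs scaling, and the function names are assumptions for illustration; convert.py works on PyTorch checkpoints and reads the scheme from the --dict checkpoint.

```python
def quantize_layer(weights, n_bits):
    """Symmetric max-abs quantization of one layer; returns (signed ints, scale)."""
    levels = 2 ** n_bits - 1
    scale = (max(abs(w) for w in weights) or 1.0) / levels
    return [round(w / scale) for w in weights], scale


def convert(model, n_bits, scheme=None):
    """Quantize every layer, per-layer precision from `scheme` if given, else n_bits."""
    converted = {}
    for name, weights in model.items():
        bits = scheme[name] if scheme is not None else n_bits  # --dict vs. --Nbits
        ints, scale = quantize_layer(weights, bits)
        converted[name] = {"bits": bits, "ints": ints, "scale": scale}
    return converted


model = {"conv1": [0.9, -0.3, 0.6], "fc": [0.4, -1.0]}
uniform = convert(model, n_bits=8)                                  # like omitting --dict
mixed = convert(model, n_bits=8, scheme={"conv1": 2, "fc": 1})     # like passing --dict
print(mixed["conv1"]["ints"], mixed["fc"]["ints"])  # [3, -1, 2] [0, -1]
```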