APoT Quantization

PyTorch implementation of APoT quantization (ICLR 2020).

(Figure: weight visualization)

@inproceedings{Li2020Additive,
  title={Additive Powers-of-Two Quantization: An Efficient Non-uniform Discretization for Neural Networks},
  author={Yuhang Li and Xin Dong and Wei Wang},
  booktitle={International Conference on Learning Representations},
  year={2020},
  url={https://openreview.net/forum?id=BkgXT24tDS}
}

This repo contains the code and data of the following paper, accepted by ICLR 2020:

Additive Powers-of-Two Quantization: An Efficient Non-uniform Discretization for Neural Networks

(Figure: the APoT quantization function)

Updates

Installation

Prerequisites

PyTorch 1.1.0 with CUDA

Dataset Preparation

ImageNet

models/quant_layer.py contains the quantization configuration. In particular, you can specify it in the QuantConv2d class:

class QuantConv2d(nn.Conv2d):
    """Generates quantized convolutional layers.

    Args:
        bit (int): bitwidth for the quantization
        power (bool): (A)PoT or uniform quantization
        additive (bool): use additive or vanilla PoT quantization

    Procedure:
        1. check whether the bitwidth is legal
        2. if using PoT quantization, build the projection (level) set
           (for 2-bit weight quantization, PoT = uniform)
        3. generate the clipping thresholds

    Forward:
        1. if bit = 32 (full precision), call the normal convolution
        2. otherwise, first normalize the weights, then quantize the weights and activations
        3. if bit = 2, apply calibrated-gradient uniform quantization to the weights
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, bit=5, power=True, additive=True, grad_scale=None):
        super(QuantConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.layer_type = 'QuantConv2d'
        self.bit = bit
        self.power = power
        self.grad_scale = grad_scale
        if power:
            if self.bit > 2:
                # weights use bit-1 levels because one bit is reserved for the sign
                self.proj_set_weight = build_power_value(B=self.bit - 1, additive=additive)
            self.proj_set_act = build_power_value(B=self.bit, additive=additive)
        # learnable clipping thresholds for activations and weights
        self.act_alpha = torch.nn.Parameter(torch.tensor(6.0))
        self.weight_alpha = torch.nn.Parameter(torch.tensor(3.0))

Here, self.bit controls the bitwidth; power=True means we use PoT or APoT quantization (use additive to choose between them). build_power_value constructs the level set Q^a(1, b) from the bit and additive parameters. If power=False, the conv layer adopts uniform quantization.
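
To make the level-set idea concrete, here is a minimal sketch of building an additive powers-of-two level set and projecting clipped weights onto it. It is an illustration only: the helper names (apot_levels_sketch, project_to_levels) and the exact term-splitting scheme are assumptions for clarity, not the repo's build_power_value implementation.

import itertools
import torch

def apot_levels_sketch(b=4, k=2):
    # Sketch of Q^a(1, b): split b bits into n = b // k additive terms,
    # let term i contribute either 0 or one power of two from a strided
    # exponent set, sum one choice per term, and normalize to [0, 1].
    # (Illustrative only; not the repo's build_power_value.)
    n = b // k
    term_choices = []
    for i in range(n):
        exps = [-(1 + i + j * n) for j in range(2 ** k - 1)]
        term_choices.append([0.0] + [2.0 ** e for e in exps])
    levels = {sum(c) for c in itertools.product(*term_choices)}
    levels = torch.tensor(sorted(levels))
    return levels / levels.max()

def project_to_levels(w, levels, alpha):
    # Clamp to [-alpha, alpha], snap each normalized magnitude to the
    # nearest level, then restore the sign and the scale.
    w = torch.clamp(w, -alpha, alpha) / alpha
    idx = torch.argmin((w.abs().unsqueeze(-1) - levels).abs(), dim=-1)
    return torch.sign(w) * levels[idx] * alpha

levels = apot_levels_sketch(b=4, k=2)
w_q = project_to_levels(torch.randn(8, 8), levels, alpha=3.0)

Note that in QuantConv2d above, the weight level set is built with B=self.bit-1 because one bit of the weight budget is reserved for the sign.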

To train a 5-bit model, just run main.py:

python main.py -a resnet18 --bit 5

Progressive initialization requires a checkpoint of a higher bitwidth. For example:

python main.py -a resnet18 --bit 4 --pretrained checkpoint/res18_5best.pth.tar
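
Under the hood, progressive initialization simply uses the higher-bit weights as the starting point for the lower-bit model. A minimal sketch, assuming the checkpoint stores its weights under a 'state_dict' key (an assumption about the checkpoint layout; main.py handles this for you when --pretrained is passed):

import torch
import torchvision

# Minimal sketch of progressive initialization.  The torchvision ResNet-18
# is only a stand-in for the quantized model that main.py builds, and the
# 'state_dict' key is an assumption, not verified against main.py.
model = torchvision.models.resnet18()
checkpoint = torch.load('checkpoint/res18_5best.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'], strict=False)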

We provide a function show_params() to print the clipping parameters of both weights and activations.
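
A usage sketch, assuming show_params() is exposed on each QuantConv2d layer (if the repo instead defines it at the model level, call model.show_params() directly):

import torch.nn as nn
from models.quant_layer import QuantConv2d

# Print the learned clipping thresholds of every quantized layer.
# The toy model below is only for illustration; in practice you would
# load the trained ResNet built by main.py first.
model = nn.Sequential(QuantConv2d(3, 16, kernel_size=3, padding=1, bit=4))
for module in model.modules():
    if isinstance(module, QuantConv2d):
        module.show_params()  # assumed to be the helper described above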

Results and Checkpoints

Checkpoints are released on Google Drive.

| Model | Precision | Hyper-Params | Accuracy (%) | Checkpoints |
|-------|-----------|--------------|--------------|-------------|
| ResNet-18 | 5-bit | batch1k_lr0.01_wd0.0001_100epoch | 70.75 | res18_5bit |
| ResNet-18 | 4-bit | batch1k_lr0.01_wd0.0001_100epoch | 70.74 | res18_4bit |
| ResNet-18 | 3-bit | batch1k_lr0.01_wd0.0001_100epoch | 69.79 | res18_3bit |
| ResNet-18 | 2-bit | batch1k_lr0.04_wd0.00002_100epoch_cg | 66.46 | res18_2bit |
| ResNet-34 | 5-bit | batch1k_lr0.1_wd0.0001_100epoch | 74.26 | res34_5bit |
| ResNet-34 | 4-bit | batch1k_lr0.1_wd0.0001_100epoch | 74.12 | res34_4bit |
| ResNet-34 | 3-bit | batch1k_lr0.1_wd0.0001_100epoch | 73.55 | res34_3bit |
| ResNet-34 | 2-bit | batch1k_lr0.1_wd0.00002_100epoch_cg | 71.30 | res34_2bit |
| ResNet-50 | 4-bit | batch512_lr0.05_wd0.0001_100epoch | 76.80 | Updating |
| ResNet-50 | 3-bit | batch512_lr0.05_wd0.0001_100epoch | 75.92 | Updating |
| ResNet-50 | 2-bit | batch512_lr0.05_wd0.00025_100epoch_cg | - | Updating |

Compared with Uniform Quantization

Use power=False to switch to uniform quantization; results:

| Model | Precision | Hyper-Params | Accuracy (%) | Compared with APoT |
|-------|-----------|--------------|--------------|--------------------|
| ResNet-18 | 4-bit | batch1k_lr0.01_wd0.0001_100epoch | 70.54 | -0.2 |
| ResNet-18 | 3-bit | batch1k_lr0.01_wd0.0001_100epoch | 69.57 | -0.22 |
| ResNet-18 | 2-bit | batch1k_lr0.01_wd0.00002_100epoch | - | Updating |
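
For reference, switching a single layer between APoT and uniform quantization only changes the power flag at construction time. A usage sketch based on the constructor shown above (the channel sizes are illustrative):

from models.quant_layer import QuantConv2d

# 4-bit APoT layer vs. 4-bit uniform layer; channel sizes are illustrative.
apot_conv = QuantConv2d(64, 64, kernel_size=3, padding=1, bit=4, power=True, additive=True)
uniform_conv = QuantConv2d(64, 64, kernel_size=3, padding=1, bit=4, power=False)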

Training and Validation Curve

cd $PATH-TO-THIS-PROJECT/ImageNet/events
tensorboard --logdir 'res18' --port 6006

(Figure: TensorBoard training and validation curves)

Hyper-Parameter Exploration

To be updated

CIFAR10

(The CIFAR10 code will be updated soon.)

The training code is inspired by pytorch-cifar-code from junyuseu.

The dataset can be downloaded automatically using torchvision. We provide shell scripts to progressively train full-precision, 4-, 3-, and 2-bit models. For example, train_res20.sh:

#!/usr/bin/env bash
python main.py --arch res20 --bit 32 -id 0,1 --wd 5e-4
python main.py --arch res20 --bit 4 -id 0,1 --wd 1e-4  --lr 4e-2 \
        --init result/res20_32bit/model_best.pth.tar
python main.py --arch res20 --bit 3 -id 0,1 --wd 1e-4  --lr 4e-2 \
        --init result/res20_4bit/model_best.pth.tar
python main.py --arch res20 --bit 2 -id 0,1 --wd 3e-5  --lr 4e-2 \
        --init result/res20_3bit/model_best.pth.tar

The checkpoint models for CIFAR10 are released:

| Model | Precision | Accuracy (%) | Checkpoints |
|-------|-----------|--------------|-------------|
| Res20 | Full Precision | 92.96 | Res20_32bit |
| Res20 | 4-bit | 92.45 | Res20_4bit |
| Res20 | 3-bit | 92.49 | Res20_3bit |
| Res20 | 2-bit | 90.96 | Res20_2bit |
| Res56 | Full Precision | 94.46 | Res56_32bit |
| Res56 | 4-bit | 93.93 | Res56_4bit |
| Res56 | 3-bit | 93.77 | Res56_3bit |
| Res56 | 2-bit | 93.05 | Res56_2bit |

To evaluate the models, you can run

python main.py -e --init result/res20_3bit/model_best.pth.tar -id 0 --bit 3

You will get the accuracy and the values of the clipping thresholds for weights and activations:

Test: [0/100]   Time 0.221 (0.221)      Loss 0.2144 (0.2144)    Prec 96.000% (96.000%)
 * Prec 92.510%
clipping threshold weight alpha: 1.569000, activation alpha: 1.438000
clipping threshold weight alpha: 1.278000, activation alpha: 0.966000
clipping threshold weight alpha: 1.607000, activation alpha: 1.293000
clipping threshold weight alpha: 1.426000, activation alpha: 1.055000
clipping threshold weight alpha: 1.364000, activation alpha: 1.720000
clipping threshold weight alpha: 1.511000, activation alpha: 1.434000
clipping threshold weight alpha: 1.600000, activation alpha: 2.204000
clipping threshold weight alpha: 1.552000, activation alpha: 1.530000
clipping threshold weight alpha: 0.934000, activation alpha: 1.939000
clipping threshold weight alpha: 1.427000, activation alpha: 2.232000
clipping threshold weight alpha: 1.463000, activation alpha: 1.371000
clipping threshold weight alpha: 1.440000, activation alpha: 2.432000
clipping threshold weight alpha: 1.560000, activation alpha: 1.475000
clipping threshold weight alpha: 1.605000, activation alpha: 2.462000
clipping threshold weight alpha: 1.436000, activation alpha: 1.619000
clipping threshold weight alpha: 1.292000, activation alpha: 2.147000
clipping threshold weight alpha: 1.423000, activation alpha: 2.329000
clipping threshold weight alpha: 1.428000, activation alpha: 1.551000
clipping threshold weight alpha: 1.322000, activation alpha: 2.574000
clipping threshold weight alpha: 1.687000, activation alpha: 1.314000

[Slides]
