deJQK / AdaBits

Bad performance on CIFAR at low bit widths #3

Open Ahmad-Jarrar opened 1 year ago

Ahmad-Jarrar commented 1 year ago

I am trying to run your experiments on CIFAR10 as described in q_resnet_uint8_train_val.yml. However, I am getting poor performance at lower bit widths. I have tried several tweaks to the config file. The result of the latest experiment is:

[image: results of the latest experiment]

I have used these parameters:

# =========================== Basic Settings ===========================
# machine info
num_gpus_per_job: 4  # number of gpus each job need
num_cpus_per_job: 63  # number of cpus each job need
memory_per_job: 200  # memory requirement each job need
gpu_type: "nvidia-tesla-v100"

# data
dataset: CIFAR10
data_transforms: cifar
data_loader: cifar
dataset_dir: ./data
data_loader_workers: 5 #10

# info
num_classes: 10
image_size: 32
topk: [1, 5]
num_epochs: 200 #150

# optimizer
optimizer: sgd
momentum: 0.9
weight_decay: 0.00004
nesterov: True

# lr
lr: 0.1 #0.05
lr_scheduler: multistep
multistep_lr_milestones: [100, 150]
multistep_lr_gamma: 0.1
#lr_scheduler: cos_annealing_iter
#lr_scheduler: butterworth_iter #mixed_iter #gaussian_iter #exp_decaying_iter #cos_annealing_iter
#exp_decaying_gamma: 0.98

# model profiling
profiling: [gpu]
#model_profiling_verbose: True

# pretrain, resume, test_only
pretrained_dir: ''
pretrained_file: ''
resume: ''
test_only: False

#
random_seed: 1995
batch_size: 128 #1024 #256 #512 #256 #1024 #4096 #1024 #256
model: ''
reset_parameters: True

#
distributed: False #True
distributed_all_reduce: False #True
use_diff_seed: False #True

#
stats_sharing: False

#
#unbiased: False
clamp: True
rescale: True #False
rescale_conv: True #False
switchbn: True #False
#normalize: False
bn_calib: True
rescale_type: constant #[stddev, constant]

#
pact_fp: False
switch_alpha: True

#
weight_quant_scheme: modified
act_quant_scheme: original

# =========================== Override Settings ===========================
#fp_pretrained_file: /path/to/best_model.pt
log_dir: ./results/cifar10/resnet20
adaptive_training: True
model: models.q_resnet_cifar
depth: 20
bits_list: [4,3,2]
weight_only: False

Kindly let me know how I can improve the results and what I am doing wrong.

deJQK commented 1 year ago

Thanks for your interest in our work. Could you please try [3, 4, 5] to see whether the issue persists? Also, what is the performance with [2]?

Ahmad-Jarrar commented 1 year ago

I have again tested using the config file provided. Only the bit-widths were changed.

[image: results for the two bit-width settings]

As you can see, with 5, 4, 3 bits the performance is fine. But with 4, 3, 2 the performance is much worse at all bit widths.

deJQK commented 1 year ago

Hi @Ahmad-Jarrar, sorry about this. The quantization scheme proposed in the paper does not converge for low bit widths, and some modification is necessary. I thought I had posted this... For proper convergence, the weights should have vanishing mean, in addition to meeting the variance requirements. To this end, you should use the following quantization method:

$$ q = 2 \cdot \frac{1}{2^b}\Bigg(\mathrm{clip}\bigg(\Big\lfloor 2^b\cdot\frac{w+1}{2} \Big\rfloor,0,2^b-1\bigg)+\frac{1}{2}\Bigg) - 1 $$

This guarantees a centered distribution for the weights.

The code is something like this:

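a = 1 << bit                       # number of levels, 2^b
res = torch.floor(a * input)       # input is expected to lie in [0, 1]
res = torch.clamp(res, max=a - 1)  # input == 1 would otherwise land on level 2^b
res.add_(0.5)                      # shift each level to its bin center
res.div_(a)                        # rescale to (0, 1)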

inside the q_k function.

You could also try this for activation quantization (without applying the outermost remapping 2x-1), but I have not tried it.

I will update the code and readme accordingly.

Best.
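
The formula and snippet above can be put together as a minimal, framework-free sketch for a single scalar weight. The function name `quantize_centered` is hypothetical and not part of the repository; the actual code operates on tensors inside q_k, so treat this only as an illustration of the arithmetic.

```python
import math


def quantize_centered(w: float, bits: int) -> float:
    """Quantize a weight in [-1, 1] to the center of one of 2**bits bins."""
    a = 1 << bits                      # number of levels, 2^b
    x = (w + 1.0) / 2.0                # map [-1, 1] to [0, 1]
    level = math.floor(a * x)          # integer bin index
    level = min(max(level, 0), a - 1)  # clip to {0, ..., 2^b - 1}
    x_q = (level + 0.5) / a            # bin center in (0, 1)
    return 2.0 * x_q - 1.0             # outermost remapping 2x - 1


if __name__ == "__main__":
    for w in (-1.0, -0.3, 0.0, 0.3, 1.0):
        print(w, quantize_centered(w, bits=2))
```

Because of the half-step offset, the endpoints w = -1 and w = 1 map to the outermost bin centers rather than to -1 and 1 themselves.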

deJQK commented 1 year ago

Hi @Ahmad-Jarrar , I have updated the readme. Hope it is clear. Thanks again for your interest in our work.

Ahmad-Jarrar commented 1 year ago

If I'm not wrong, the code given does not apply the outermost 2x-1.

deJQK commented 1 year ago

https://github.com/deJQK/AdaBits/blob/master/models/quant_ops.py#L142-L143

Ahmad-Jarrar commented 1 year ago

Yes, I noticed it later. Thank you so much for your help.

haiduo commented 1 year ago

Hello @deJQK, I don't understand the statement "For proper convergence, it should be better to have vanishing mean for weights, besides proper variance requirements." Why is a vanishing mean for the weights better for convergence? Could you give a more specific explanation? Additionally, the formula does not seem to match the code: [image]

Looking forward to your reply, thank you.

deJQK commented 1 year ago

Hi @haiduo, you could check these papers: https://arxiv.org/pdf/1502.01852.pdf, https://arxiv.org/pdf/1606.05340.pdf, https://arxiv.org/pdf/1611.01232.pdf, all of which analyze training dynamics for centered weights. I am not sure how to analyze weights with nonzero mean.

For +1 and -1, please check here and here.
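
As a rough sketch of where the zero-mean assumption enters in the first of these papers (1502.01852), consider a pre-activation y that sums n products w_i x_i, with weights and inputs each i.i.d. and mutually independent (an idealization, stated here as an assumption):

$$ \mathrm{E}[y] = n\,\mathrm{E}[w]\,\mathrm{E}[x], \qquad \mathrm{Var}[y] = n\Big(\mathrm{Var}[w]\,\mathrm{E}[x^2] + \mathrm{E}[w]^2\,\mathrm{Var}[x]\Big) $$

With E[w] = 0 this reduces to Var[y] = n Var[w] E[x^2], the relation used there to derive the variance condition. A nonzero weight mean adds terms that grow with the fan-in n.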

haiduo commented 1 year ago

Thank you for your reply, @deJQK! So "vanishing mean for weights" just means adding 0.5 after the q_k function, with everything else unchanged, right? For b=4, I interpret this as mapping [-1, 1] to [0, 1] to [0, 15] to [0.5, 15.5] to [-0.9375, 0.9375]. Does this correspond to the third picture below, "Centered Symmetric"? [image] It doesn't seem right, and I am confused. If convenient, could you please send me the code for the four diagrams above? That might help me understand. Thank you very much! You can email me at 'huanghd@stu.xjtu.edu.cn'; I am very interested in your work.

deJQK commented 1 year ago

Hi @haiduo, thanks again for your interest. For b=4, it maps [-1, 1] to [0, 1], to {0, 1, ..., 15}, to {0.5, 1.5, ..., 15.5}, to {1/32, 3/32, ..., 31/32}, to {-15/16, -13/16, ..., 13/16, 15/16}. Code for all four schemes is available in the repo and you could check the related lines.
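
The chain of mappings above can be checked with exact arithmetic. The sketch below enumerates the output levels for b=4 and, for comparison, a hypothetical variant without the half-step offset (not claimed to be one of the repository's schemes), to show how the +0.5 centers the levels. The helper names `levels` and `mean` are assumptions made for illustration.

```python
from fractions import Fraction


def levels(bits: int, offset: Fraction) -> list:
    """Output levels 2 * (k + offset) / 2**bits - 1 for k = 0, ..., 2**bits - 1."""
    a = 1 << bits
    return [2 * (k + offset) / a - 1 for k in range(a)]


def mean(values: list) -> Fraction:
    """Exact arithmetic mean of a list of Fractions."""
    return sum(values) / len(values)


centered = levels(4, Fraction(1, 2))  # with the +0.5 offset
uncentered = levels(4, Fraction(0))   # hypothetical: offset dropped

print(centered[0], centered[-1])  # -15/16 15/16
print(mean(centered))             # 0
print(mean(uncentered))           # -1/16
```

With the offset, the 16 levels are symmetric about zero. Without it, the levels run from -1 to 14/16 and their mean shifts to -1/16.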

haiduo commented 1 year ago

OK, thank you!

haiduo commented 1 year ago

Hi @deJQK, sorry, one more thing. Could you confirm the following two points?

  1. So "vanishing mean for weights" just means adding 0.5 after the q_k function, with everything else unchanged, right?
  2. For b=4, it maps [-1, 1] to [0, 1], to {0, 1, ..., 15}, to {0.5, 1.5, ..., 15.5}, to {1/32, 3/32, ..., 31/32}, to {-15/16, -13/16, ..., 13/16, 15/16}. Does this correspond to the third picture, "Centered Symmetric"?

deJQK commented 1 year ago

@haiduo, yes for both.