quic / aimet

AIMET is a library that provides advanced quantization and compression techniques for trained neural network models.
https://quic.github.io/aimet-pages/index.html

QAT extremely slow with per channel #1528

Open wuguangbin1230 opened 2 years ago

wuguangbin1230 commented 2 years ago

Hi all, can you tell me why per-channel QAT is extremely slow? My code and the measured times are listed below.
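For background, "per channel" means each output channel of a weight tensor gets its own quantization scale and zero-point instead of one pair for the whole tensor. A minimal standalone sketch in plain PyTorch (toy shapes for illustration, not AIMET internals):

import torch

w = torch.randn(64, 32, 3, 3)  # toy conv weight: 64 output channels

# per-tensor: a single scale/zero-point for the whole tensor
scale_t = (w.abs().max() / 127.0).item()
q_t = torch.fake_quantize_per_tensor_affine(w, scale_t, 0, -128, 127)

# per-channel: one scale/zero-point per output channel (axis 0)
scale_c = w.abs().amax(dim=(1, 2, 3)) / 127.0
zp_c = torch.zeros(64, dtype=torch.int32)
q_c = torch.fake_quantize_per_channel_affine(w, scale_c, zp_c, 0, -128, 127)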

My per-channel code:

import copy
import time
from functools import partial

import progressbar
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from aimet_common.defs import QuantScheme
from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters
from aimet_torch.batch_norm_fold import fold_all_batch_norms
from aimet_torch.model_preparer import prepare_model
from aimet_torch.model_validator.model_validator import ModelValidator
from aimet_torch.quantsim import QuantizationSimModel

from nets.Backbone_net import Backbone_net

def pass_calibration_data(sim_model, data_loader, iterations=5, use_cuda=True):
    """Run a few calibration batches through the sim model for compute_encodings."""
    assert iterations > 0
    device = torch.device('cpu')
    if use_cuda:
        if torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            raise RuntimeError("Found no CUDA Device while use_cuda is selected")

    sim_model = sim_model.to(device)
    sim_model.eval()

    batch_cntr = 1
    progress_bar = progressbar.ProgressBar().start()
    with torch.no_grad():
        for mini_batch_data in data_loader:
            img = mini_batch_data.to(device)
            sim_model(img)
            progress_bar.update(int(batch_cntr / iterations * 100))
            batch_cntr += 1
            if batch_cntr > iterations:
                break
    progress_bar.finish()
    return None

# placeholders from the original post: dataset, args_shape, log_dir and args
# are defined elsewhere in the user's project
data_loader = DataLoader(dataset)
dummy_input = torch.rand(args_shape).cuda()

model = Backbone_net().cuda()
model.eval()
prepared_model = prepare_model(model)
ModelValidator.validate_model(prepared_model, model_input=dummy_input)

bn_folded_model = copy.deepcopy(prepared_model)
_ = fold_all_batch_norms(bn_folded_model, input_shapes=args_shape)

params = AdaroundParameters(data_loader=data_loader, num_batches=4, default_num_iterations=32,
                            default_reg_param=0.01, default_beta_range=(20, 2))
adarounded_model = Adaround.apply_adaround(bn_folded_model, dummy_input, params, path=log_dir,
                                           filename_prefix='adaround', default_param_bw=8,
                                           default_quant_scheme=QuantScheme.post_training_tf_enhanced,
                                           default_config_file='./aimet_config_perchannel.json')

# build the sim on the adarounded model (the original snippet passed the raw
# model here, which discards the AdaRound result)
quantsim = QuantizationSimModel(model=adarounded_model, quant_scheme='tf_enhanced',
                                config_file='./aimet_config_perchannel.json',
                                dummy_input=dummy_input, rounding_mode='nearest',
                                default_output_bw=8, default_param_bw=8, in_place=False)
quantsim.compute_encodings(
    forward_pass_callback=partial(pass_calibration_data, iterations=5, use_cuda=True),
    forward_pass_callback_args=data_loader)

model = quantsim.model
optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999),
                       weight_decay=args.weight_decay)
cur_iters = 0
for mini_batch_data in data_loader:
    cur_iters += 1
    imgs = mini_batch_data.cuda()

    model.train()
    time0 = time.time()
    loss = model(imgs)  # if model.eval() came first, model(imgs) would run at normal speed
    time_range = time.time() - time0
    print(f"model.train spent time of {cur_iters} inters = {time_range}")

    model.eval()
    time1 = time.time()
    loss = model(imgs)  # with model.train() called before model(imgs), it becomes extremely slow
    time_range1 = time.time() - time1
    print(f"model.eval spent time of {cur_iters} inters = {time_range1}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
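(An aside, not the cause of the slowdown: AIMET's AdaRound examples typically load and freeze the AdaRound parameter encodings on the sim before calibration, so that compute_encodings does not overwrite the rounded weights. Assuming the 'adaround' filename prefix and log_dir used above, that step looks roughly like this sketch:)

import os

# sketch following AIMET's published AdaRound example flow; call this
# before quantsim.compute_encodings(...) so the rounded weights are kept
quantsim.set_and_freeze_param_encodings(
    encoding_path=os.path.join(log_dir, 'adaround.encodings'))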

Under model.train() and model.eval(), the time spent on each forward step with batch size 16 is as follows:

model.train spent time of 1 inters = 23.023080825805664
model.eval  spent time of 1 inters = 0.09899067878723145
model.train spent time of 2 inters = 21.375224113464355
model.eval spent time of 2 inters = 0.09649300575256348
model.train spent time of 3 inters = 21.304085731506348
model.eval spent time of 3 inters = 0.07633590698242188
model.train spent time of 4 inters = 21.28969645500183
model.eval spent time of 4 inters = 0.07672667503356934
model.train spent time of 5 inters = 21.269729614257812
model.eval spent time of 5 inters = 0.07581830024719238
model.train spent time of 6 inters = 21.275647401809692
model.eval spent time of 6 inters = 0.07834458351135254
model.train spent time of 7 inters = 21.266693115234375
model.eval spent time of 7 inters = 0.07533025741577148
model.train spent time of 8 inters = 21.266995191574097
model.eval spent time of 8 inters = 0.07567667961120605
model.train spent time of 9 inters = 21.257187366485596
model.eval spent time of 9 inters = 0.07504153251647949
model.train spent time of 10 inters = 21.264954090118408
model.eval spent time of 10 inters = 0.0758976936340332
model.train spent time of 11 inters = 21.27093195915222
model.eval spent time of 11 inters = 0.07504415512084961
model.train spent time of 12 inters = 21.268646955490112
model.eval spent time of 12 inters = 0.07635951042175293
model.train spent time of 13 inters = 21.25061321258545
model.eval spent time of 13 inters = 0.07557082176208496
model.train spent time of 14 inters = 21.253055095672607
model.eval spent time of 14 inters = 0.07657814025878906
model.train spent time of 15 inters = 21.260530710220337
model.eval spent time of 15 inters = 0.0758061408996582
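A note on these measurements: CUDA kernels launch asynchronously, so wall-clock timing around model(imgs) can be misleading. The ~21 s vs ~0.08 s gap above is far too large to be a synchronization artifact, but for precise per-step numbers a pattern like this sketch is safer:

import time
import torch

def timed_forward(model, imgs):
    # drain pending GPU work before and after, so the measured interval
    # covers only the kernels launched by this forward pass
    torch.cuda.synchronize()
    t0 = time.time()
    out = model(imgs)
    torch.cuda.synchronize()
    return out, time.time() - t0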

Contents of aimet_config_perchannel.json:

{"defaults": {
    "ops": {
        "is_output_quantized": "True",
        "is_symmetric": "False"
    },
    "params": {
        "is_quantized": "True",
        "is_symmetric": "False"
    },
    "strict_symmetric": "False",
    "unsigned_symmetric": "True",
    "per_channel_quantization": "True"
},
"params": {
    "bias": {
        "is_quantized": "True"
    }
},
"op_type": {},
"supergroups": [
    {
        "op_list": ["Conv", "Relu"]
    },
    {
        "op_list": ["Conv", "Clip"]
    },
    {
        "op_list": ["Add", "Relu"]
    },
    {
        "op_list": ["Gemm", "Relu"]
    }
],
"model_input": {
    "is_input_quantized": "True"
},
"model_output": {}}
quic-mangal commented 1 year ago

@wuguangbin1230, sorry for the late reply. We are currently working on speed-ups/enhancements for per-channel QAT. We will keep you posted when those changes are complete.