ZouJiu1 / LSQplus

LSQ+ or LSQplus

Question about the implementation of the initialization method from the LSQplus paper #8

Status: Closed (TaooCAI closed this issue 1 year ago)

TaooCAI commented 1 year ago

def update_LSQplus_activation_Scalebeta(model):
    for name, child in model.named_children():
        if isinstance(child, (QuantConv2d, QuantConvTranspose2d, QuantLinear)):
            weight = child.weight.data
            s = child.activation_quantizer.s.data
            beta = child.activation_quantizer.beta.data
            Qn = child.activation_quantizer.Qn
            Qp = child.activation_quantizer.Qp
            g = child.activation_quantizer.g
            # print('before: ', name, child.activation_quantizer.s.grad.data, child.activation_quantizer.beta.grad.data, s, beta)
            q_w = (weight - beta) / s
            # print(q_w)
            smaller = (q_w < Qn).float()  # convert bool to float: 1.0 or 0.0
            bigger = (q_w > Qp).float()  # convert bool to float: 1.0 or 0.0
            between = 1.0 - smaller - bigger  # mask of elements that fall inside the quantization range
            grad_alpha = ((smaller * Qn + bigger * Qp + 
                           between * Round.apply(q_w) - between * q_w) * g).sum().unsqueeze(dim=0)
            grad_beta = ((smaller + bigger) * g).sum().unsqueeze(dim=0)
            # print('grad_beta: ',grad_beta,g, smaller.sum(), bigger.sum(), between.sum(),Qn, Qp)
            child.activation_quantizer.s.grad.data.add_(g*(2*(child.quant_input-child.input)*grad_alpha).sum().unsqueeze(dim=0))
            child.activation_quantizer.beta.grad.data.add_(g*(2*(child.quant_input-child.input)*grad_beta).sum().unsqueeze(dim=0))
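
The `smaller` / `bigger` / `between` masks in the snippet encode the straight-through-estimator (STE) gradients of the LSQ+ quantizer x_hat = s * clamp(round((x - beta) / s), Qn, Qp) + beta. A minimal per-element sketch of those local gradients, written in plain Python instead of PyTorch and leaving out the gradient scale `g`, might look like this. The names `lsqplus_quantize` and `lsqplus_local_grads` are hypothetical and are not part of the repository.

```python
def lsqplus_quantize(x: float, s: float, beta: float, Qn: int, Qp: int) -> float:
    """LSQ+ fake quantization of one value: clamp(round((x - beta) / s)) * s + beta."""
    q_bar = min(max(round((x - beta) / s), Qn), Qp)
    return q_bar * s + beta


def lsqplus_local_grads(x: float, s: float, beta: float, Qn: int, Qp: int) -> tuple[float, float]:
    """STE gradients (d x_hat / d s, d x_hat / d beta) for one element, without g."""
    q = (x - beta) / s
    if q < Qn:                  # clipped below: x_hat = Qn * s + beta
        return float(Qn), 1.0
    if q > Qp:                  # clipped above: x_hat = Qp * s + beta
        return float(Qp), 1.0
    # Inside the range round() is treated as identity in the backward pass,
    # so d x_hat / d s = round(q) - q and the two beta terms cancel out.
    return float(round(q) - q), 0.0


# Usage with hypothetical values: signed 3-bit range (Qn = -4, Qp = 3), s = 0.5, beta = 0.25
for x in (0.875, 5.0, -5.0):
    print(x, lsqplus_quantize(x, 0.5, 0.25, -4, 3), lsqplus_local_grads(x, 0.5, 0.25, -4, 3))
```

Python's built-in `round()` rounds halves to even, which matches `torch.round`; whether the repository's `Round` function behaves the same way is an assumption here.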

ZouJiu1 commented 1 year ago

I fixed it; you can try it again: https://github.com/ZouJiu1/LSQplus/blob/master/quantization/lsqplus_quantize_V1.py#L210-L238
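
The `2*(child.quant_input-child.input)` factor in the original snippet suggests that the initialization fits s and beta by minimizing the squared quantization error sum((x_hat - x)**2). In the quoted snippet, `grad_alpha` and `grad_beta` are already reduced to scalars by `.sum()` before they multiply the error term, while the chain rule pairs each element's error with that element's own local gradient. The following is a hedged sketch of the per-element version in plain Python; it is not the linked fix, and `init_scale_beta`, `lr`, and `steps` are illustrative assumptions.

```python
def quantize(x: float, s: float, beta: float, Qn: int, Qp: int) -> float:
    """LSQ+ fake quantization of one value."""
    q_bar = min(max(round((x - beta) / s), Qn), Qp)
    return q_bar * s + beta


def local_grads(x: float, s: float, beta: float, Qn: int, Qp: int) -> tuple[float, float]:
    """STE gradients (d x_hat / d s, d x_hat / d beta) for one element."""
    q = (x - beta) / s
    if q < Qn:
        return float(Qn), 1.0
    if q > Qp:
        return float(Qp), 1.0
    return float(round(q) - q), 0.0


def mse(xs: list[float], s: float, beta: float, Qn: int, Qp: int) -> float:
    """Squared quantization error summed over all elements."""
    return sum((quantize(x, s, beta, Qn, Qp) - x) ** 2 for x in xs)


def mse_grads(xs: list[float], s: float, beta: float, Qn: int, Qp: int) -> tuple[float, float]:
    """Gradients of mse() w.r.t. s and beta: multiply per element, then sum."""
    gs = gb = 0.0
    for x in xs:
        err = quantize(x, s, beta, Qn, Qp) - x
        ds, db = local_grads(x, s, beta, Qn, Qp)
        gs += 2.0 * err * ds
        gb += 2.0 * err * db
    return gs, gb


def init_scale_beta(xs: list[float], s: float, beta: float, Qn: int, Qp: int,
                    lr: float = 1e-4, steps: int = 300) -> tuple[float, float]:
    """Hypothetical helper: plain gradient descent on the MSE from a starting (s, beta)."""
    for _ in range(steps):
        gs, gb = mse_grads(xs, s, beta, Qn, Qp)
        s = max(s - lr * gs, 1e-8)  # keep the step size positive
        beta -= lr * gb
    return s, beta


# Usage with made-up data in [-2, 2] and a deliberately too-small initial s
xs = [i / 10 for i in range(-20, 21)]
s0, b0 = 0.1, 0.0
s1, b1 = init_scale_beta(xs, s0, b0, -4, 3)
print(mse(xs, s0, b0, -4, 3), mse(xs, s1, b1, -4, 3), s1, b1)
```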

TaooCAI commented 1 year ago

Thanks. One more question: why do we need to multiply by g twice here, once for the (quant_input - input)**2 term and once for grad_alpha? I found that the results are bad if we use (child.quant_input - child.input) * grad_alpha instead.
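
For reference, the chain rule behind this question can be sketched as follows, assuming the initialization minimizes the squared quantization error of the activations, round() is passed straight through in the backward pass, and the parameters are updated by plain SGD with a hypothetical learning rate eta. The last line shows what happens if the LSQ gradient scale g multiplies both the local gradient and the outer error term.

```latex
q_i = \frac{x_i - \beta}{s}, \qquad
\hat{x}_i = s \,\operatorname{clamp}\!\left(\operatorname{round}(q_i), Q_n, Q_p\right) + \beta, \qquad
\mathcal{L}(s, \beta) = \sum_i \left(\hat{x}_i - x_i\right)^2

\frac{\partial \mathcal{L}}{\partial s} = \sum_i 2\left(\hat{x}_i - x_i\right)\frac{\partial \hat{x}_i}{\partial s},
\qquad
\frac{\partial \hat{x}_i}{\partial s} =
\begin{cases}
Q_n & q_i < Q_n \\
Q_p & q_i > Q_p \\
\operatorname{round}(q_i) - q_i & \text{otherwise}
\end{cases}

\frac{\partial \mathcal{L}}{\partial \beta} = \sum_i 2\left(\hat{x}_i - x_i\right)\frac{\partial \hat{x}_i}{\partial \beta},
\qquad
\frac{\partial \hat{x}_i}{\partial \beta} =
\begin{cases}
1 & q_i < Q_n \text{ or } q_i > Q_p \\
0 & \text{otherwise}
\end{cases}

\Delta s = -\eta \, g \sum_i 2\left(\hat{x}_i - x_i\right)\left(g \, \frac{\partial \hat{x}_i}{\partial s}\right)
         = -\eta \, g^2 \, \frac{\partial \mathcal{L}}{\partial s}
```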

ZouJiu1 commented 1 year ago

In these two lines, g can be either removed or kept; it only balances (slows down) the learning rate of those parameters: https://github.com/ZouJiu1/LSQplus/blob/master/quantization/lsqplus_quantize_V1.py#L229-L230
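
To illustrate why g acts like a learning-rate knob: with plain SGD, multiplying a gradient by a constant g gives exactly the same update as multiplying the learning rate by g, so applying g twice is the same as using lr * g**2. The sketch below assumes g follows the LSQ paper's 1 / sqrt(N * Qp); whether the repository computes g this way, and the tensor size and values used, are assumptions.

```python
import math


def lsq_grad_scale(num_elements: int, Qp: int) -> float:
    """LSQ-style gradient scale g = 1 / sqrt(N * Qp) (assumed here)."""
    return 1.0 / math.sqrt(num_elements * Qp)


def sgd_step(param: float, grad: float, lr: float) -> float:
    """One plain SGD update."""
    return param - lr * grad


# Hypothetical numbers: 4096 activations quantized to unsigned 8 bits (Qp = 255)
g = lsq_grad_scale(4096, 255)
s, raw_grad, lr = 0.05, 2.0, 1e-2

once = sgd_step(s, g * raw_grad, lr)       # g applied once
twice = sgd_step(s, g * g * raw_grad, lr)  # g applied twice, as in the code discussed above

print(once, sgd_step(s, raw_grad, lr * g))        # same step: g works like an lr multiplier
print(twice, sgd_step(s, raw_grad, lr * g * g))   # same step with lr * g**2
```

With an adaptive optimizer such as Adam, a constant scale on the gradient is largely normalized away (up to the epsilon term), so removing or keeping g matters less there than with plain SGD.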