megvii-research / FQ-ViT

[IJCAI 2022] FQ-ViT: Post-Training Quantization for Fully Quantized Vision Transformer
Apache License 2.0
301 stars 48 forks source link

关于QINTLayerNorm和ptf中的计算问题 #34

Closed caoliyi closed 1 year ago

caoliyi commented 1 year ago

根据论文中关于ptf的公式 (9)、(10) image 按照我的理解S应该对应代码中的scale1,但是代码中计算zero_point中的S对应的是scale8,请问那个公式是正确的?

def get_quantization_params(self, inputs, *args, **kwargs):
        max_val = self.max_val
        min_val = self.min_val

        qmax = self.bit_type.upper_bound
        qmin = self.bit_type.lower_bound

        best_score = 1e+10
        max_val_t = max_val.max()
        min_val_t = min_val.min()
        scale8 = (max_val_t - min_val_t) / float(qmax - qmin)
        scale8.clamp_(self.eps)
        scale4 = scale8 / 2
        scale2 = scale4 / 2
        scale1 = scale2 / 2
        zero_point = qmin - torch.round(min_val_t / scale8)
        zero_point.clamp_(qmin, qmax)
        scale_mask = torch.ones_like(max_val)
        for j in range(inputs.shape[2]):
            data = inputs[..., j].unsqueeze(-1)
            data_q1 = ((data / scale1 + zero_point).round().clamp(qmin, qmax) -
                       zero_point) * scale1
            data_q2 = ((data / scale2 + zero_point).round().clamp(qmin, qmax) -
                       zero_point) * scale2
            data_q4 = ((data / scale4 + zero_point).round().clamp(qmin, qmax) -
                       zero_point) * scale4
            data_q8 = ((data / scale8 + zero_point).round().clamp(qmin, qmax) -
                       zero_point) * scale8
            score1 = lp_loss(data, data_q1, p=2.0, reduction='all')
            score2 = lp_loss(data, data_q2, p=2.0, reduction='all')
            score4 = lp_loss(data, data_q4, p=2.0, reduction='all')
            score8 = lp_loss(data, data_q8, p=2.0, reduction='all')
            score = [score1, score2, score4, score8]
            scale_mask[j] *= 2**score.index(min(score))
        scale = scale1 * scale_mask
        return scale, zero_point

另外代码中关于QINTLayerNorm中的计算,按照论文中(24)、(25)所提供的公式, s应该对应原始的in_scale8, 但计算中使用的是in_scale1,请问s的正确表达。

关于x_q的计算:

x_q = (x / in_scale).round()
            in_scale1 = in_scale.min()
            in_scale_mask = (in_scale / in_scale1).round()

            x_q = x_q * in_scale_mask

并没有计算zero_point,请问是为什么?

关于A的计算这里使用了 A_sign = A.sign() 在论文中没有找到对应的公式,请问为什么要这样计算?

十分感谢

linyang-zhh commented 1 year ago
  1. Eq (9) 中s代表scale1;Eq. (10) 书写存在问题,此处的s应为scale8(即 $s \times 2^K$),感谢纠正,我们之后会对论文进行修改;
  2. Eq. (24), (25)中的s都是scale1,如Eq (9) 所述;
  3. “并没有计算zero_point”,如Eq. (22), (23), zero_point只用来平移,不影响放缩,可以理解成 (x / in_scale + zero_point) - zero_point
  4. 关于A_sign = A.sign(),对应Eq. (29-32)部分,但对其进行了省略,公式中只考虑了A为正的情况,而代码是正负均可,请以代码为准,后续我们会对论文进行修改。