megvii-research / FQ-ViT

[IJCAI 2022] FQ-ViT: Post-Training Quantization for Fully Quantized Vision Transformer
Apache License 2.0
301 stars 48 forks source link

QIntLayerNorm 中 Get_MN中的bit #33

Closed caoliyi closed 1 year ago

caoliyi commented 1 year ago

想请问下get_MN中bit设置为7的原因,这个bit数与什么有关?
以及想问下N中的取值为范围什么是(0,31)谢谢

 def get_MN(self, x):
        bit = 7
        N = torch.clamp(bit - torch.floor(torch.log2(x)), 0, 31)
        M = torch.clamp(torch.floor(x * torch.pow(2, N)), 0, 2**(bit + 1) - 1)
        return M, N
linyang-zhh commented 1 year ago

此处代码最好修改成如下 [已修改, at 9905bf ]

def get_MN(x):
    bit = 8
    N = torch.clamp(bit - 1 - torch.floor(torch.log2(x)), 0, 31)
    M = torch.clamp(torch.floor(x * torch.pow(2, N)), 0, 2**bit - 1)
    return M, N
  1. 通过bit参数确保M是一个uint8(0-255)的数值;
  2. N的范围需在0-31之间,以避免移位溢出

此组公式只是为了快速获得一个“较好”的M与N,另一种较慢但更准确的方法是:从0-31遍历N,求出对应的M及模拟误差,之后选择最优的一组M与N。

caoliyi commented 1 year ago

不好意思,我还是没有太理解数字为什么N的范围需要在0-31来避免溢出,这个部分不是8bit计算的吗?

linyang-zhh commented 1 year ago

8bit卷积后的结果通常保存为int32,之后通过requant(此处会使用到M和N)来转换为int8