Jermmy / pytorch-quantization-demo

A simple network quantization demo using pytorch from scratch.
Apache License 2.0

When the zero point falls outside the (qmin, qmax) range, should the original range (rmin, rmax) be expanded? #5

Closed: Usigned closed this issue 3 years ago

Usigned commented 3 years ago

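I quantized the vector [1, 2] to 8 bits and then dequantized it, and got [1, 1] back, which is a large loss of precision. If I change the calcScaleZeroPoint function as follows, the result is much better: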

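def calcScaleZeroPoint(rmin, rmax, num_bits=8):
    qmin = 0
    qmax = 2 ** num_bits - 1
    # S = (rmax - rmin) / (qmax - qmin); assumes rmax > rmin
    scale = float((rmax - rmin) / (qmax - qmin))

    # Z = qmax - rmax / S; lands outside [qmin, qmax] when 0 is not in [rmin, rmax]
    zero_point = qmax - rmax / scale

    # When out of range, clamp the zero point and recalculate the scale
    if zero_point < qmin:
        # rmin > 0: map [0, rmax] onto [qmin, qmax]
        zero_point = qmin
        scale = float((rmax - 0) / (qmax - qmin))
    elif zero_point > qmax:
        # rmax < 0: map [rmin, 0] onto [qmin, qmax]
        zero_point = qmax
        scale = float((0 - rmin) / (qmax - qmin))

    # Round to the nearest integer; int() alone would truncate
    zero_point = int(round(zero_point))

    return scale, zero_point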
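A minimal pure-Python sketch of the round trip described above, for comparing the two behaviours. The helpers `quantize`, `dequantize`, `calc_clamp_only` and `calc_rescaled` are hypothetical names introduced here. `calc_clamp_only` assumes the original function clamps the zero point into [qmin, qmax] while keeping the scale computed from (rmin, rmax); the thread itself does not show that code. Python's built-in `round` is used for rounding.

```python
def quantize(x, scale, zero_point, num_bits=8):
    # q = clamp(round(x / S + Z), qmin, qmax)
    qmin, qmax = 0, 2 ** num_bits - 1
    return [min(max(round(v / scale + zero_point), qmin), qmax) for v in x]


def dequantize(q, scale, zero_point):
    # r = S * (q - Z)
    return [scale * (v - zero_point) for v in q]


def calc_clamp_only(rmin, rmax, num_bits=8):
    # Assumed original behaviour: clamp the zero point, keep the scale.
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = min(max(qmax - rmax / scale, qmin), qmax)
    return scale, int(round(zero_point))


def calc_rescaled(rmin, rmax, num_bits=8):
    # Proposed behaviour: when the zero point is clamped, recompute the scale
    # so that [0, rmax] (or [rmin, 0]) spans the full integer range.
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = qmax - rmax / scale
    if zero_point < qmin:
        zero_point, scale = qmin, rmax / (qmax - qmin)
    elif zero_point > qmax:
        zero_point, scale = qmax, -rmin / (qmax - qmin)
    return scale, int(round(zero_point))


x = [1.0, 2.0]
for calc in (calc_clamp_only, calc_rescaled):
    s, z = calc(min(x), max(x))
    print(calc.__name__, s, z, dequantize(quantize(x, s, z), s, z))
```

Under these assumptions the clamp-only variant maps both elements of [1, 2] back to about 1.0, because 2 saturates at qmax, while the rescaled variant recovers approximately [1.0, 2.0], with the first element off by less than one quantization step (2/255). The rescaled variant still spends integer codes on [0, rmin), which no input uses; that is the cost of keeping 0 exactly representable.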
Usigned commented 3 years ago

After updating the update method in QParam, I no longer run into this problem.
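The thread does not show what changed in `QParam.update`. One common way to make the out-of-range case impossible, offered here as an assumption rather than the repository's confirmed code, is to widen the tracked range so it always contains 0: then rmin <= 0 <= rmax, the zero point qmax - rmax / S always lands in [qmin, qmax], and no rescaling branch is needed. A minimal sketch; `RangeTracker` and its methods are hypothetical names, not the repository's `QParam`.

```python
class RangeTracker:
    """Hypothetical min/max observer that always keeps 0 inside [rmin, rmax]."""

    def __init__(self, num_bits=8):
        self.num_bits = num_bits
        self.rmin = None
        self.rmax = None

    def update(self, values):
        lo, hi = min(values), max(values)
        # Track the running range over all observed batches.
        self.rmin = lo if self.rmin is None else min(self.rmin, lo)
        self.rmax = hi if self.rmax is None else max(self.rmax, hi)
        # Widen the range to include 0 so the zero point stays representable.
        self.rmin = min(self.rmin, 0.0)
        self.rmax = max(self.rmax, 0.0)

    def scale_zero_point(self):
        qmin, qmax = 0, 2 ** self.num_bits - 1
        if self.rmax == self.rmin:
            # Degenerate all-zero range: any positive scale works.
            return 1.0, qmin
        scale = (self.rmax - self.rmin) / (qmax - qmin)
        # Because rmin <= 0 <= rmax, this lies in [qmin, qmax] up to rounding error.
        zero_point = int(round(qmax - self.rmax / scale))
        return scale, min(max(zero_point, qmin), qmax)


tracker = RangeTracker()
tracker.update([1.0, 2.0])
print(tracker.scale_zero_point())
```

For [1, 2] this tracker widens the range to [0, 2], giving S = 2/255 and Z = 0, the same parameters the modified calcScaleZeroPoint above produces for this input.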