aaron-xichen / pytorch-playground

Base pretrained models and datasets in pytorch (MNIST, SVHN, CIFAR10, CIFAR100, STL10, AlexNet, VGG16, VGG19, ResNet, Inception, SqueezeNet)
MIT License

The integer bits and mantissa bits after log min max quantize #9

Closed: zhangying0127 closed this issue 6 years ago

zhangying0127 commented 6 years ago

I do not understand log min max quantization very well. Can you explain which bits are the integer part and which are the mantissa part after log min max quantization? How can I get the integer representation of the quantized number? Thank you so much.

aaron-xichen commented 6 years ago

The log min max method actually first scales any float tensor to [0, 1], followed by normal linear quantization.

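import math
import torch

def linear_quantize(input, sf, bits):
    # Fixed-point quantization: step size delta = 2**-sf, signed range of `bits` bits
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input) - 1
    delta = math.pow(2.0, -sf)
    bound = math.pow(2.0, bits-1)
    min_val = - bound
    max_val = bound - 1
    rounded = torch.floor(input / delta + 0.5)  # round to the nearest multiple of delta

    clipped_value = torch.clamp(rounded, min_val, max_val) * delta  # clamp to the integer range, then rescale
    return clipped_value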

torch.clamp(rounded, min_val, max_val), before it is multiplied by delta, is the integer representation.
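To make the integer/fractional split concrete, here is a minimal, dependency-free sketch that mirrors linear_quantize() on plain Python floats. Because delta = 2**-sf, it reads sf as the number of fractional bits, so a bits-bit code q stands for q * 2**-sf, with the top bits - sf bits (sign included) acting as the integer part and the low sf bits as the fractional part (the "mantissa" in the question). That reading is an assumption drawn from the code above, not something stated in the thread. linear_quantize_codes and split_fixed_point are hypothetical helper names, and the sketch only covers bits >= 2 and 0 <= sf <= bits.

```python
import math


def linear_quantize_codes(values, sf, bits):
    """Hypothetical helper mirroring linear_quantize() on plain Python floats.

    Returns the clamped integer codes q and the dequantized values q * 2**-sf.
    Assumption: sf counts fractional bits (follows from delta = 2**-sf).
    """
    assert bits >= 2, "this sketch skips the special 1-bit branch"
    delta = math.pow(2.0, -sf)            # step size: one unit in the last fractional bit
    bound = math.pow(2.0, bits - 1)
    min_val, max_val = -bound, bound - 1  # signed bits-bit integer range
    codes = []
    for x in values:
        rounded = math.floor(x / delta + 0.5)  # round half up, like torch.floor(... + 0.5)
        codes.append(int(min(max(rounded, min_val), max_val)))
    return codes, [q * delta for q in codes]


def split_fixed_point(q, sf, bits):
    """Split a bits-bit two's-complement code into integer and fractional bit strings.

    The top (bits - sf) bits, including the sign bit, form the integer part;
    the low sf bits form the fractional part.
    """
    assert 0 <= sf <= bits, "sketch assumes the binary point lies inside the word"
    pattern = format(q & ((1 << bits) - 1), f"0{bits}b")  # two's-complement bit string
    return pattern[: bits - sf], pattern[bits - sf:]
```

For example, with sf=2 and bits=4 the inputs [0.3, -1.2, 5.0] map to codes [1, -5, 7] (5.0 saturates at max_val = 7), i.e. values [0.25, -1.25, 1.75]. split_fixed_point(-5, 2, 4) returns ("10", "11"), which reads as a signed integer part of -2 plus a fraction of 0.75, giving -1.25.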

jeff830107 commented 6 years ago

Why do you say that the log_minmax method first scales any float tensor to [0, 1] and then applies normal linear quantization? It seems that it calls the min_max_quantize() function, not the linear_quantize() function.

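import torch

def log_minmax_quantize(input, bits):
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input)  # 1 bit: keep only the sign (a tensor, like the main branch)
    s = torch.sign(input)                          # remember the sign
    input0 = torch.log(torch.abs(input) + 1e-20)   # move magnitudes to the log domain
    v = min_max_quantize(input0, bits)             # min-max quantize in the log domain (defined elsewhere in the repo)
    v = torch.exp(v) * s                           # back to the linear domain, reattach the sign
    return v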
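For the original question (the integer representation under log min max), a sketch of the log-domain pipeline that returns integer codes instead of dequantized floats may help. The body of min_max_quantize() is not shown in this thread, so the sketch assumes it is a plain min-max (affine) quantizer: rescale to [0, 1] using the tensor's min and max, then round to one of 2**bits levels (2**bits - 1 steps). Under that assumption there is no fixed integer/fraction split at all: each code is an unsigned index into log-spaced levels between lo and hi, the sign is kept separately, and decoding also needs (lo, hi). log_minmax_codes and log_minmax_dequantize are hypothetical helpers on plain Python lists, not functions from the repo.

```python
import math


def log_minmax_codes(values, bits, eps=1e-20):
    """Hypothetical sketch: integer codes for log-domain min-max quantization.

    Mirrors the structure of log_minmax_quantize(): take log|x|, rescale to
    [0, 1] with min-max, round to one of 2**bits - 1 steps, keep the sign aside.
    Assumption: min_max_quantize() is a plain affine min-max quantizer.
    """
    assert bits >= 2, "this sketch skips the 1-bit branch"
    signs = [(x > 0) - (x < 0) for x in values]      # -1, 0 or 1, like torch.sign
    logs = [math.log(abs(x) + eps) for x in values]  # log-magnitude domain
    lo, hi = min(logs), max(logs)
    n = 2 ** bits - 1                                 # number of steps between lo and hi
    span = hi - lo
    if span == 0:
        codes = [0] * len(values)  # all magnitudes equal: a single level is enough
    else:
        codes = [math.floor((v - lo) / span * n + 0.5) for v in logs]
    return signs, codes, lo, hi


def log_minmax_dequantize(signs, codes, lo, hi, bits):
    """Invert the codes: code -> log value -> exp -> reattach the sign."""
    n = 2 ** bits - 1
    return [s * math.exp(lo + q / n * (hi - lo)) for s, q in zip(signs, codes)]
```

For example, log_minmax_codes([1.0, -3.0, 100.0], bits=3) yields signs [1, -1, 1] and codes [0, 2, 7]; decoding gives approximately 1.0, -3.73 and 100.0, because -3.0 is snapped to the nearest level in the log domain, 100**(2/7).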