aaron-xichen / pytorch-playground

Base pretrained models and datasets in PyTorch (MNIST, SVHN, CIFAR10, CIFAR100, STL10, AlexNet, VGG16, VGG19, ResNet, Inception, SqueezeNet)
MIT License
2.62k stars, 612 forks

Quantization of skip connections #7

Closed: ptillet closed this issue 6 years ago

ptillet commented 6 years ago

I think there may be a problem with the way the package handles skip connections such as those found in ResNets. As I understand it, the residual mapping F(x) + x is quantized as Quantize(F(Quantize(x, s1)), s2) + Quantize(x, s1), which adds two tensors that live on different scales. If I'm right, the output should still be correct (because the activations are rescaled back to 2^sf and the computations are carried out in float), but the residual addition would effectively be computed at a higher precision than the intended bit width.
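
To make the scale mismatch concrete, here is a minimal pure-Python sketch. The `quantize` helper is a hypothetical signed fixed-point quantizer (round to the nearest multiple of 2**-sf, clamp to a `bits`-bit integer range) written for illustration; it is not claimed to be the package's own function. `branch` is a made-up stand-in for F, `bits_needed` is a hypothetical helper, and the bit width and scale factors are arbitrary.

```python
import math

def quantize(x, sf, bits):
    """Hypothetical fixed-point quantizer: round x to the nearest multiple
    of 2**-sf (round half up) and clamp to a signed `bits`-bit range."""
    delta = 2.0 ** -sf
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return min(max(math.floor(x / delta + 0.5), lo), hi) * delta

def bits_needed(v, sf):
    """Smallest signed bit width that holds v exactly as an integer number
    of 2**-sf steps, or None if v is not on that grid."""
    steps = v / 2.0 ** -sf
    if not steps.is_integer():
        return None
    n, b = int(steps), 1
    while not -(2 ** (b - 1)) <= n <= 2 ** (b - 1) - 1:
        b += 1
    return b

def branch(v):
    # Hypothetical stand-in for the residual branch F.
    return 0.5 * v - 0.3

BITS, S1, S2 = 4, 2, 3  # arbitrary bit width and scale factors
x = [1.3, -0.6, 0.9]

q_x = [quantize(v, S1, BITS) for v in x]            # Quantize(x, s1): 0.25 grid
q_f = [quantize(branch(v), S2, BITS) for v in q_x]  # Quantize(F(.), s2): 0.125 grid
naive = [f + r for f, r in zip(q_f, q_x)]           # sum is never re-quantized

print(naive)                                # [1.625, -1.0, 1.25]
print([bits_needed(v, S2) for v in naive])  # [5, 4, 5]
print(bits_needed(1.625, S1))               # None
```

Both inputs to the addition fit in 4 bits, but their sum lands on the finer 0.125 grid while spanning the coarser range: 1.625 needs 5 bits at scale s2 and is not representable at all at scale s1.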

aaron-xichen commented 6 years ago

I think you are right. A more accurate simulation would be Quantize(Quantize(F(Quantize(x, s1)), s2) + Quantize(x, s1), s3). I found that the original (wrong) quantization method made little difference to accuracy at bit widths of 8 bits or more. At 6 bits, however, the two methods differ substantially: the new (and presumably correct) method causes a larger accuracy drop. Please check out the branch fix_issue#7 for details. I really appreciate you finding this; you clearly have a deep understanding of the quantization process.
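
A minimal sketch of the corrected formula, for illustration only: `quantize` and `branch` are hypothetical helpers (not the package's code), and the bit width and scale factors are arbitrary. In particular s3 is simply picked by hand here; how the fix_issue#7 branch chooses it is not shown. The only change from the original scheme is one more quantization of the sum at scale s3.

```python
import math

def quantize(x, sf, bits):
    """Hypothetical fixed-point quantizer: round x to a multiple of 2**-sf
    (round half up) and clamp to the signed `bits`-bit integer range."""
    delta = 2.0 ** -sf
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return min(max(math.floor(x / delta + 0.5), lo), hi) * delta

def branch(v):
    # Hypothetical stand-in for the residual branch F.
    return 0.5 * v - 0.3

def residual_naive(x, s1, s2, bits):
    # Quantize(F(Quantize(x, s1)), s2) + Quantize(x, s1)
    qx = quantize(x, s1, bits)
    return quantize(branch(qx), s2, bits) + qx

def residual_fixed(x, s1, s2, s3, bits):
    # Quantize(Quantize(F(Quantize(x, s1)), s2) + Quantize(x, s1), s3)
    return quantize(residual_naive(x, s1, s2, bits), s3, bits)

BITS, S1, S2, S3 = 4, 2, 3, 2  # arbitrary bit width and scale factors
xs = [1.3, -0.6, 0.9]
naive = [residual_naive(v, S1, S2, BITS) for v in xs]
fixed = [residual_fixed(v, S1, S2, S3, BITS) for v in xs]
print(naive)  # [1.625, -1.0, 1.25]
print(fixed)  # [1.75, -1.0, 1.25]

# Each fixed output is an integer number of 2**-S3 steps that fits in BITS bits.
steps = [v / 2.0 ** -S3 for v in fixed]
print(steps)  # [7.0, -4.0, 5.0]
```

The extra rounding step puts the residual output back on a single `bits`-bit grid, at the cost of additional rounding error (here 0.125, half of a 2**-s3 step, on the first element).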

ptillet commented 6 years ago

Yes, that's right. I am building highly specialized INT8 GPU kernels for quantization and ran into the same issue :) When I tried to fix it myself, I also saw a larger accuracy drop in some cases, which doesn't make sense to me. I am still investigating.