pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

Int8 much slower than half precision #93

Status: Closed (closed by tsaizhenling 4 years ago)

tsaizhenling commented 4 years ago

Tested on a GeForce RTX 2070 GPU. Model: resnet18, traced with the following Python script:

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
with torch.jit.optimized_execution(True):
    traced_script_module = torch.jit.trace(model, example)
    output = traced_script_module(torch.ones(1, 3, 224, 224))
    print(output[0, :5])
    traced_script_module.save("traced_resnet_model.pt")

bazel run //cpp/benchmark --cxxopt="-DNDEBUG" --cxxopt="-DJIT" --cxxopt="-DTRT" <path to .pt> "(1 3 224 224)"

[JIT/TRT]: batch_size: 1
    Average latency: 1.53636 ms
    Average FPS: 650.889 fps
    Latency Standard Deviation: 0.0887908
    FPS Standard Deviation: 39.0779
(excluding initial warmup runs)
[JIT]: batch_size: 1
    Average latency: 2.62867 ms
    Average FPS: 380.421 fps
    Latency Standard Deviation: 0.0781293
    FPS Standard Deviation: 14.7786
(excluding initial warmup runs)
ok

Then a second run with --cxxopt="-DHALF":

[JIT/TRT]: batch_size: 1
    Average latency: 0.921167 ms
    Average FPS: 1085.58 fps
    Latency Standard Deviation: 0.0322918
    FPS Standard Deviation: 33.7609
(excluding initial warmup runs)
[JIT]: batch_size: 1
    Average latency: 2.56868 ms
    Average FPS: 389.305 fps
    Latency Standard Deviation: 0.0382577
    FPS Standard Deviation: 6.43474
(excluding initial warmup runs)
ok

Change L154 in cpp/ptq/main.cpp to match the dims above, then run:

bazel run //cpp/ptq -- <path to .pt> <path to cifar10 data>

[JIT model FP32]: batch_size: 1
    Average latency: 2.64305 ms
    Average FPS: 378.35 fps
    Latency Standard Deviation: 0.0818512
    FPS Standard Deviation: 14.3568
(excluding initial warmup runs)
[TRT quantized model]: batch_size: 1
    Average latency: 1.42412 ms
    Average FPS: 702.189 fps
    Latency Standard Deviation: 0.0751042
    FPS Standard Deviation: 31.2702
(excluding initial warmup runs)

I observe little appreciable speedup from INT8 quantization compared to full-precision JIT/TRT (1.42 ms vs. 1.54 ms average latency). In fact, half precision is much faster (0.92 ms). I understand that hardware plays a role and my GPU is probably not built for INT8 inference, but I did not expect INT8 to fare worse than half precision.

narendasan commented 4 years ago

Hi, thanks for filing this. The large performance difference between FP16 and INT8 was caused by FP16 kernels not being enabled in INT8 mode. TensorRT will not restrict itself to INT8 kernels unless strict types are specified. For a specific model, batch size, and/or input size, an FP16 kernel may be faster than the INT8 kernel, or an applicable INT8 kernel may not be available. When FP16 kernels are not enabled, TensorRT can only fall back on FP32 in those cases, which explains the large difference. Typically, if INT8 mode is enabled but no INT8 layers are selected (everything falls back on FP16), you might see a slightly slower engine than pure FP16 because of the format layers inserted to convert to/from FP32 data. PR #95 enables FP16 kernels in INT8 mode, so I would try again with master and check whether the performance gap is still as large.
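The fallback behavior described above can be illustrated with a small toy model. This is only a sketch: the layer names and per-precision timings are invented for illustration, and `pick_kernels` is a hypothetical helper, not TensorRT's actual tactic selection. It assumes the builder simply picks the fastest kernel per layer among the precisions it is allowed to use, and it ignores the cost of reformat layers.

```python
# Toy model only: layer names and timings (ms) are invented for illustration
# and are NOT measurements from TensorRT.
HYPOTHETICAL_LAYER_TIMINGS_MS = {
    "conv1": {"fp32": 0.30, "fp16": 0.15, "int8": 0.10},
    "layer1": {"fp32": 0.40, "fp16": 0.18, "int8": 0.12},
    # In this toy model, no INT8 kernel exists for the last two layers.
    "avgpool": {"fp32": 0.05, "fp16": 0.03},
    "fc": {"fp32": 0.20, "fp16": 0.08},
}


def pick_kernels(timings, enabled):
    """For each layer, pick the fastest kernel among the enabled precisions."""
    plan = {}
    for layer, kernels in timings.items():
        candidates = {p: t for p, t in kernels.items() if p in enabled}
        if not candidates:
            raise ValueError(f"no enabled kernel for layer {layer!r}")
        plan[layer] = min(candidates, key=candidates.get)
    return plan


def total_latency(timings, plan):
    """Sum the per-layer timings of the chosen kernels."""
    return sum(timings[layer][precision] for layer, precision in plan.items())


# INT8 mode without FP16 kernels: layers lacking an INT8 kernel drop to FP32.
int8_without_fp16 = pick_kernels(HYPOTHETICAL_LAYER_TIMINGS_MS, {"int8", "fp32"})
# INT8 mode with FP16 kernels also enabled: those layers can use FP16 instead.
int8_with_fp16 = pick_kernels(HYPOTHETICAL_LAYER_TIMINGS_MS, {"int8", "fp16", "fp32"})

for name, plan in [("INT8 + FP32", int8_without_fp16), ("INT8 + FP16 + FP32", int8_with_fp16)]:
    latency = total_latency(HYPOTHETICAL_LAYER_TIMINGS_MS, plan)
    print(f"{name}: {plan} -> {latency:.2f} ms")
```

In this toy setup the layers without an INT8 kernel run in FP32 unless FP16 is also enabled, which is the kind of gap the explanation above attributes to the original build.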

I tested resnet18 on an RTX 2080 Ti and saw better performance in INT8. My guess is that the kernels selected on the 2080 Ti are primarily INT8, whereas on the 2070 more FP16 kernels get selected. You can turn on debug mode to see which kernels TensorRT selects. For example, here are the ones selected for me:

DEBUG: [TRTorch Conversion Context] - Engine Layer Information:
DEBUG: [TRTorch Conversion Context] - Layer(Reformat): %input0.1 : Tensor = aten::_convolution(%input.11, %self.conv1.weight, %9, %8, %7, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %114 : Tensor = aten::relu_(%input.16) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0 input reformatter 0, Tactic: 0, input.11[Float(3,224,224)] -> %input0.1 : Tensor = aten::_convolution(%input.11, %self.conv1.weight, %9, %8, %7, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %114 : Tensor = aten::relu_(%input.16) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0 reformatted input 0[Int8(3,224,224)]
DEBUG: [TRTorch Conversion Context] - Layer(): %input0.1 : Tensor = aten::_convolution(%input.11, %self.conv1.weight, %9, %8, %7, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %114 : Tensor = aten::relu_(%input.16) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -5510956450195747703, %input0.1 : Tensor = aten::_convolution(%input.11, %self.conv1.weight, %9, %8, %7, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %114 : Tensor = aten::relu_(%input.16) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0 reformatted input 0[Int8(3,224,224)] -> 114[Int8(64,112,112)]
DEBUG: [TRTorch Conversion Context] - Layer(Pooling): %input.10 : Tensor = aten::max_pool2d(%114, %7, %8, %6, %6, %10) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:538:0, Tactic: -4, 114[Int8(64,112,112)] -> input.10[Int8(64,56,56)]
DEBUG: [TRTorch Conversion Context] - Layer(FusedConvActDirect): %input.14 : Tensor = aten::_convolution(%input.10, %self.layer1.0.conv1.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %118 : Tensor = aten::relu_(%input.12) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: 14, input.10[Int8(64,56,56)] -> 118[Int8(64,56,56)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.13 : Tensor = aten::_convolution(%118, %self.layer1.0.conv2.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %input.15 : Tensor = aten::add_(%out.3, %input.10, %13) # /opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:70:0 + %122 : Tensor = aten::relu_(%input.15) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -8431788508843860955, 118[Int8(64,56,56)], input.10[Int8(64,56,56)] -> 122[Int8(64,56,56)]
DEBUG: [TRTorch Conversion Context] - Layer(FusedConvActDirect): %input.17 : Tensor = aten::_convolution(%122, %self.layer1.1.conv1.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %125 : Tensor = aten::relu_(%input.18) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: 14, 122[Int8(64,56,56)] -> 125[Int8(64,56,56)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.19 : Tensor = aten::_convolution(%125, %self.layer1.1.conv2.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %input.20 : Tensor = aten::add_(%out.4, %122, %13) # /opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:70:0 + %129 : Tensor = aten::relu_(%input.20) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -496455309852654971, 125[Int8(64,56,56)], 122[Int8(64,56,56)] -> 129[Int8(64,56,56)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.21 : Tensor = aten::_convolution(%129, %self.layer2.0.conv1.weight, %9, %8, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %132 : Tensor = aten::relu_(%input.22) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -8431788508843860955, 129[Int8(64,56,56)] -> 132[Int8(128,28,28)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.24 : Tensor = aten::_convolution(%129, %self.layer2.0.downsample.0.weight, %9, %8, %5, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0, Tactic: -3065443564418447033, 129[Int8(64,56,56)] -> identity.2[Int8(128,28,28)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.23 : Tensor = aten::_convolution(%132, %self.layer2.0.conv2.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %input.25 : Tensor = aten::add_(%out.5, %identity.2, %13) # /opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:70:0 + %138 : Tensor = aten::relu_(%input.25) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -8431788508843860955, 132[Int8(128,28,28)], identity.2[Int8(128,28,28)] -> 138[Int8(128,28,28)]
DEBUG: [TRTorch Conversion Context] - Layer(FusedConvActDirect): %input.26 : Tensor = aten::_convolution(%138, %self.layer2.1.conv1.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %141 : Tensor = aten::relu_(%input.27) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: 14, 138[Int8(128,28,28)] -> 141[Int8(128,28,28)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.28 : Tensor = aten::_convolution(%141, %self.layer2.1.conv2.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %input.29 : Tensor = aten::add_(%out.6, %138, %13) # /opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:70:0 + %145 : Tensor = aten::relu_(%input.29) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -8431788508843860955, 141[Int8(128,28,28)], 138[Int8(128,28,28)] -> 145[Int8(128,28,28)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.30 : Tensor = aten::_convolution(%145, %self.layer3.0.conv1.weight, %9, %8, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %148 : Tensor = aten::relu_(%input.31) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -8431788508843860955, 145[Int8(128,28,28)] -> 148[Int8(256,14,14)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.33 : Tensor = aten::_convolution(%145, %self.layer3.0.downsample.0.weight, %9, %8, %5, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0, Tactic: -8431788508843860955, 145[Int8(128,28,28)] -> identity.3[Int8(256,14,14)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.32 : Tensor = aten::_convolution(%148, %self.layer3.0.conv2.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %input.34 : Tensor = aten::add_(%out.7, %identity.3, %13) # /opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:70:0 + %154 : Tensor = aten::relu_(%input.34) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -8431788508843860955, 148[Int8(256,14,14)], identity.3[Int8(256,14,14)] -> 154[Int8(256,14,14)]
DEBUG: [TRTorch Conversion Context] - Layer(FusedConvActDirect): %input.35 : Tensor = aten::_convolution(%154, %self.layer3.1.conv1.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %157 : Tensor = aten::relu_(%input.36) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: 162, 154[Int8(256,14,14)] -> 157[Int8(256,14,14)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.37 : Tensor = aten::_convolution(%157, %self.layer3.1.conv2.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %input.38 : Tensor = aten::add_(%out.8, %154, %13) # /opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:70:0 + %161 : Tensor = aten::relu_(%input.38) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -8431788508843860955, 157[Int8(256,14,14)], 154[Int8(256,14,14)] -> 161[Int8(256,14,14)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.6 : Tensor = aten::_convolution(%161, %self.layer4.0.conv1.weight, %9, %8, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %164 : Tensor = aten::relu_(%input.7) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -8431788508843860955, 161[Int8(256,14,14)] -> 164[Int8(512,7,7)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.8 : Tensor = aten::_convolution(%161, %self.layer4.0.downsample.0.weight, %9, %8, %5, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0, Tactic: 4871133328510103657, 161[Int8(256,14,14)] -> identity.1[Int8(512,7,7)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.5 : Tensor = aten::_convolution(%164, %self.layer4.0.conv2.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %input.9 : Tensor = aten::add_(%out.2, %identity.1, %13) # /opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:70:0 + %170 : Tensor = aten::relu_(%input.9) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -8431788508843860955, 164[Int8(512,7,7)], identity.1[Int8(512,7,7)] -> 170[Int8(512,7,7)]
DEBUG: [TRTorch Conversion Context] - Layer(FusedConvActDirect): %input.3 : Tensor = aten::_convolution(%170, %self.layer4.1.conv1.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %173 : Tensor = aten::relu_(%input.4) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: 162, 170[Int8(512,7,7)] -> 173[Int8(512,7,7)]
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): %input.1 : Tensor = aten::_convolution(%173, %self.layer4.1.conv2.weight, %9, %6, %6, %6, %10, %5, %13, %10, %10, %11) # /opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0 + %input.2 : Tensor = aten::add_(%out.1, %170, %13) # /opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:70:0 + %177 : Tensor = aten::relu_(%input.2) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1060:0, Tactic: -8431788508843860955, 173[Int8(512,7,7)], 170[Int8(512,7,7)] -> 177[Int8(512,7,7)]
DEBUG: [TRTorch Conversion Context] - Layer(Reformat): %x.1 : Tensor = aten::adaptive_avg_pool2d(%177, %6) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:888:0 input reformatter 0, Tactic: 0, 177[Int8(512,7,7)] -> %x.1 : Tensor = aten::adaptive_avg_pool2d(%177, %6) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:888:0 reformatted input 0[Float(512,7,7)]
DEBUG: [TRTorch Conversion Context] - Layer(PoolingTiled): %x.1 : Tensor = aten::adaptive_avg_pool2d(%177, %6) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:888:0, Tactic: 7995649, %x.1 : Tensor = aten::adaptive_avg_pool2d(%177, %6) # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:888:0 reformatted input 0[Float(512,7,7)] -> x.1[Float(512,1,1)]
DEBUG: [TRTorch Conversion Context] - Layer(Constant): %180 : Tensor = aten::matmul(%input0.2, %2) [Freeze Tensor(other)], Tactic: 0,  -> (Unnamed Layer* 68) [Constant]_output[Float(1000)]
DEBUG: [TRTorch Conversion Context] - Layer(MatrixMultiply): %180 : Tensor = aten::matmul(%input0.2, %2), Tactic: 0, input0.2[Float(512)], (Unnamed Layer* 68) [Constant]_output[Float(1000)] -> 180[Float(1000)]
DEBUG: [TRTorch Conversion Context] - Layer(Constant): %181 : Tensor = trt::const(%self.fc.bias) + [Reshape self to [1000] for broadcasting (%182 : Tensor = aten::add_(%181, %180, %13))], Tactic: 0,  -> (Unnamed Layer* 71) [Shuffle]_output[Float(1000)]
DEBUG: [TRTorch Conversion Context] - Layer(ElementWise): %182 : Tensor = aten::add_(%181, %180, %13), Tactic: 1, (Unnamed Layer* 71) [Shuffle]_output[Float(1000)], 180[Float(1000)] -> 182[Float(1000)]
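To get a quick overview of a log like the one above, the output precision of each engine layer can be tallied with a short script. This is a minimal sketch that assumes each layer line ends with the output tensor in the form `name[Dtype(dims)]`, as in the lines shown here; `count_output_precisions` is a hypothetical helper, and the two sample lines are shortened (the `...` stands in for the long layer descriptions).

```python
import re
from collections import Counter

# Two shortened lines in the same shape as the debug log; "..." replaces
# the long layer descriptions.
SAMPLE_LOG = """\
DEBUG: [TRTorch Conversion Context] - Layer(i8816cudnn): ..., Tactic: -8431788508843860955, 173[Int8(512,7,7)], 170[Int8(512,7,7)] -> 177[Int8(512,7,7)]
DEBUG: [TRTorch Conversion Context] - Layer(PoolingTiled): ..., Tactic: 7995649, ... reformatted input 0[Float(512,7,7)] -> x.1[Float(512,1,1)]
"""

# Matches the "Layer(<kind>):" prefix that marks an engine layer entry.
LAYER_RE = re.compile(r"Layer\((?P<kind>[^)]*)\):")
# Matches the data type of the last tensor on the line, e.g. "[Int8(512,7,7)]".
OUTPUT_TYPE_RE = re.compile(r"\[(?P<dtype>[A-Za-z0-9]+)\([\d,]*\)\]\s*$")


def count_output_precisions(log_text):
    """Count engine layers by the data type of their final (output) tensor."""
    counts = Counter()
    for line in log_text.splitlines():
        if LAYER_RE.search(line) is None:
            continue  # not an engine layer entry
        match = OUTPUT_TYPE_RE.search(line)
        if match:
            counts[match.group("dtype")] += 1
    return counts


print(count_output_precisions(SAMPLE_LOG))
```

Running the same function over a full debug log from each GPU would show how many layers ended up with Int8 outputs versus Float outputs on that device.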

Also, to run the ptq application properly on a network like ResNet, you need to apply this patch, or a similar one adjusted for your input size:

ptq.patch.txt

tsaizhenling commented 4 years ago

Thanks for the detailed explanation, @narendasan. Here are my numbers after quantizing successfully:

[TRT quantized model]: batch_size: 1
    Average latency: 0.639157 ms
    Average FPS: 1564.56 fps
    Latency Standard Deviation: 0.0474436
    FPS Standard Deviation: 84.4725
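For reference, the average latencies reported in this thread can be put side by side with a few lines of Python. This is a sketch rather than part of the benchmark: the labels are my own, the speedups are relative to the first JIT FP32 run (the JIT baseline varied slightly between runs), and FPS is recomputed as 1000 / latency, which may differ slightly from the averaged FPS the benchmark prints.

```python
# Average latencies (ms) copied from the benchmark outputs in this thread
# (resnet18, batch size 1, GeForce RTX 2070). Labels are informal.
latency_ms = {
    "JIT FP32": 2.62867,
    "TRT FP32": 1.53636,
    "TRT FP16": 0.921167,
    "TRT INT8 (first run)": 1.42412,
    "TRT INT8 (second run)": 0.639157,
}

# Speedup of each configuration relative to the first JIT FP32 measurement.
baseline = latency_ms["JIT FP32"]
speedup = {name: baseline / ms for name, ms in latency_ms.items()}

for name, ms in latency_ms.items():
    fps = 1000.0 / ms  # frames per second implied by the average latency
    print(f"{name:<22} {ms:7.3f} ms  {fps:8.1f} fps  {speedup[name]:5.2f}x")
```

With these numbers, the second INT8 run is faster than the FP16 run, while the first INT8 run sits close to TRT FP32.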