NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

How does TensorRT implement the `Add` in INT8 mode? #1144

Closed JosephChenHub closed 3 years ago

JosephChenHub commented 3 years ago

Description

We ran comparison experiments with the following two settings:

  1. Via the pytorch_quantization tool, parse the exported ONNX model with Q/DQ nodes to generate the calibration table, then build the TensorRT engine (a sketch of this workflow follows the list).
  2. Use TensorRT directly to do PTQ calibration.
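
For reference, a minimal sketch of the method-1 workflow with pytorch_quantization; `build_model` and `calib_loader` are hypothetical stand-ins for your own network and calibration data:

```python
import torch
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn

quant_modules.initialize()           # monkey-patch nn.Conv2d etc. with quantized variants

model = build_model().cuda().eval()  # hypothetical: any network built after initialize()

# Run calibration: switch every TensorQuantizer into statistics-collection mode.
for m in model.modules():
    if isinstance(m, quant_nn.TensorQuantizer) and m._calibrator is not None:
        m.disable_quant()
        m.enable_calib()
with torch.no_grad():
    for images, _ in calib_loader:   # hypothetical calibration data loader
        model(images.cuda())
for m in model.modules():
    if isinstance(m, quant_nn.TensorQuantizer) and m._calibrator is not None:
        m.load_calib_amax()          # turn collected statistics into amax values
        m.enable_quant()
        m.disable_calib()

# Export ONNX with explicit Q/DQ (QuantizeLinear/DequantizeLinear) nodes.
quant_nn.TensorQuantizer.use_fb_fake_quant = True
dummy = torch.randn(1, 3, 640, 640, device="cuda")
torch.onnx.export(model, dummy, "model_qdq.onnx", opset_version=13)
```

The exported `model_qdq.onnx` can then be built into an engine, e.g. with `trtexec --onnx=model_qdq.onnx --int8`.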

However, we observe that the inference speed of method 1 is worse than FP16. [benchmark screenshot] Here TRT means the engine generated via method 2 and QDQ the engine generated via method 1. By checking the log, we find that some layers of method 1 still remain in FP32, e.g.: [screenshot of the verbose log, 2021-03-23 14-10-47] Conv_101 + Relu_108 and Add still run in FP32 mode, and Conv_101 is the first conv block, as shown in the exported ONNX model. So the questions are:

  1. Why does Conv_101 + Relu_108 fail to be converted into INT8 mode?
  2. How is the Add operator implemented in INT8 mode, as the following log of method 2 shows? [screenshot of the verbose log, 2021-03-23 14-16-11] For example, take z = x + y with max(abs(x)) = 2.5 and max(abs(y)) = 2.5: we can add the quantized numbers first and then dequantize the result, but what if the two inputs have different maximum ranges? (A numerical sketch follows this list.)
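
A numerical sketch of the usual dequantize-accumulate-requantize scheme for an INT8 `Add` with mismatched input ranges. The arithmetic TensorRT actually emits in its kernels is internal, so treat this as illustrative, not as TRT's implementation:

```python
import numpy as np

def quantize(x, amax):
    # Symmetric INT8 quantization: map [-amax, amax] onto [-127, 127].
    scale = amax / 127.0
    return int(np.clip(np.round(x / scale), -127, 127)), scale

x, y = 1.7, -3.2
qx, sx = quantize(x, amax=2.5)  # each input carries its own scale
qy, sy = quantize(y, amax=5.0)

# Fold the per-input scales into the accumulation, then requantize to the
# calibrated output scale sz: the inputs never need to share a range.
sz = 6.0 / 127.0
qz = int(np.clip(np.round((qx * sx + qy * sy) / sz), -127, 127))
print(qz * sz, "~", x + y)  # dequantized result approximates the FP32 sum
```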

Environment

TensorRT Version: 7.2.1
NVIDIA GPU: RTX 2080 Ti
NVIDIA Driver Version: 440
CUDA Version: 10.2
CUDNN Version: 8.0
Operating System: Ubuntu 18.04
Python Version (if applicable): 3.6
Tensorflow Version (if applicable):
PyTorch Version (if applicable): 1.7.1
Baremetal or Container (if so, version):

Relevant Files

Part of the exported ONNX model with Q/DQ nodes:

[screenshot: exported ONNX model with Q/DQ nodes]

ttyio commented 3 years ago

Hello @JosephChenHub, TRT can fuse conv + relu + add together, but the fused output and the residual have to be the same data type, and the monkey patch in QAT cannot recognize this conv + relu + residual pattern. Could you follow the linked code to add a quantizer on the residual? Thanks!

https://github.com/NVIDIA/TensorRT/blob/master/tools/pytorch-quantization/examples/torchvision/models/classification/resnet.py#L219
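
Schematically, the pattern in the linked example looks like the hypothetical residual block below (not the user's actual model): the extra `TensorQuantizer` on the skip path gives the `Add` two INT8 operands, which is what allows TRT to keep the conv + relu + add fusion in INT8.

```python
import torch.nn as nn
from pytorch_quantization import nn as quant_nn

class QuantResidualBlock(nn.Module):
    """Hypothetical residual block following the linked resnet.py pattern."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = quant_nn.QuantConv2d(channels, channels, 3, padding=1)
        self.conv2 = quant_nn.QuantConv2d(channels, channels, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        # Quantize the skip connection so both Add inputs are INT8;
        # without this the fused output and the residual differ in type.
        self.residual_quantizer = quant_nn.TensorQuantizer(
            quant_nn.QuantConv2d.default_quant_desc_input)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        out = out + self.residual_quantizer(identity)
        return self.relu(out)
```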

JosephChenHub commented 3 years ago

> Hello @JosephChenHub, TRT can fuse conv + relu + add together, but the fused output and the residual have to be the same data type, and the monkey patch in QAT cannot recognize this conv + relu + residual pattern. Could you follow the linked code to add a quantizer on the residual? Thanks!
>
> https://github.com/NVIDIA/TensorRT/blob/master/tools/pytorch-quantization/examples/torchvision/models/classification/resnet.py#L219

Hi, thanks! I added the quantizer on the residual but it still fails to be converted into INT8 mode. Does the concat affect this fusion? [screenshot: ONNX graph around the concat]

ttyio commented 3 years ago

Hello @JosephChenHub, yes, the concat here would break the fusion pattern I mentioned before. Is this a different case? I did not see a concat in the original description.

JosephChenHub commented 3 years ago

> Hello @JosephChenHub, yes, the concat here would break the fusion pattern I mentioned before. Is this a different case? I did not see a concat in the original description.

I see, the concat is indeed a copy operation in TRT and it falls into FP32 mode. So how do we prevent that?

```
[TensorRT] VERBOSE: Layer(ElementWise): Add_132, Tactic: 1, 451[Float(32,160,160)], Conv_142 + Relu_149 || Conv_84 + Relu_91[Float(32,160,160)] -> 478[Float(32,160,160)]
[TensorRT] VERBOSE: Layer(Reformat): 477 copy, Tactic: 0, Conv_142 + Relu_149 || Conv_84 + Relu_91[Float(32,160,160)] -> 478[Float(32,160,160)]
```

ttyio commented 3 years ago

Hello @JosephChenHub, do you have the full verbose log that I can check? Thanks!

JosephChenHub commented 3 years ago

> Hello @JosephChenHub, do you have the full verbose log that I can check? Thanks!

Hi @ttyio, I inserted quantizers into the Add and Concat; the evaluation result is 0.250 mAP / 544 FPS, while the TRT (method 2) result is 0.327 mAP / 544 FPS. I guess the degradation comes from the Add or Concat, since they have two different range scales, as shown in the following graph.

[screenshot: graph with the two different range scales]
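
The eventual fix is not spelled out in this thread, but one commonly used workaround for mismatched ranges (not necessarily what was done here) is to share a single `TensorQuantizer` across all tensors feeding the same `Add`/`Concat`, so calibration produces one amax and both operands carry identical scales; a sketch with hypothetical wiring:

```python
import torch
from pytorch_quantization import nn as quant_nn

# One quantizer shared by every operand of the elementwise Add / Concat:
# during calibration it observes all inputs, so the resulting amax covers
# their joint range and both operands end up with identical scales.
shared_q = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)

def quant_add(a, b):
    return shared_q(a) + shared_q(b)

def quant_concat(tensors, dim=1):
    return torch.cat([shared_q(t) for t in tensors], dim=dim)
```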

JosephChenHub commented 3 years ago

> > Hello @JosephChenHub, do you have the full verbose log that I can check? Thanks!
>
> Hi @ttyio, I inserted quantizers into the Add and Concat; the evaluation result is 0.250 mAP / 544 FPS, while the TRT (method 2) result is 0.327 mAP / 544 FPS. I guess the degradation comes from the Add or Concat, since they have two different range scales, as shown in the following graph.
>
> [screenshot: graph with the two different range scales]

OK, I have solved this issue, and it's interesting that the result reaches 0.340 mAP / 544 FPS (TRT: 0.327 mAP / 544 FPS).

ttyio commented 3 years ago

Cool @JosephChenHub! For the concat question: TRT supports per-channel scales internally for concat.

May I know how you fixed the issue? Thanks!
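
To illustrate what per-channel scales buy at a `Concat` (schematic only, not TRT's actual internal layout): the output channels coming from each input can simply keep that input's scale, so no requantization or shared range is needed.

```python
import numpy as np

# Two INT8 inputs with different per-tensor scales (two channels each).
a_q, a_scale = np.array([100, -50], dtype=np.int8), 2.5 / 127
b_q, b_scale = np.array([30, 127], dtype=np.int8), 5.0 / 127

# Channel-wise concat: each output channel inherits its source's scale,
# so the copy needs no requantization and loses no accuracy.
out_q = np.concatenate([a_q, b_q])
out_scale = np.array([a_scale, a_scale, b_scale, b_scale])
print(out_q * out_scale)  # dequantized output
```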

weixiaolian21 commented 3 years ago

> > Hello @JosephChenHub, do you have the full verbose log that I can check? Thanks!
>
> Hi @ttyio, I inserted quantizers into the Add and Concat; the evaluation result is 0.250 mAP / 544 FPS, while the TRT (method 2) result is 0.327 mAP / 544 FPS. I guess the degradation comes from the Add or Concat, since they have two different range scales, as shown in the following graph. [screenshot: graph with the two different range scales]
>
> OK, I have solved this issue, and it's interesting that the result reaches 0.340 mAP / 544 FPS (TRT: 0.327 mAP / 544 FPS).

Hello, I met the same issue as you. Can you share the way you solved it? May I have your QQ or WeChat to learn from you?

imyhxy commented 3 years ago

Hi, would you share how you solved this problem?

Mavericky-j commented 2 years ago

Does anyone have ideas about dealing with the different scales before the Add node?