NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

Obvious performance degradation for TRT-fp16 model compared to original Pytorch model #2942

Open KingofAsianPopJC opened 1 year ago

KingofAsianPopJC commented 1 year ago

Description

The accuracy of TRT-FP32 and ONNX Runtime matches the original PyTorch model, while TRT-FP16 shows an obvious degradation. What is the reason, and how can it be fixed?


Environment

PyTorch: 2.0.0
CUDA: 11.4
cuDNN: 8.6.0
TensorRT: 8.5 GA
GPU: NVIDIA A100
GPU driver version: 515.86.01
Operating system: Ubuntu 20.04
Python: 3.10

If some layers or operations of the model need to be modified to improve TRT-FP16 accuracy, how can I locate them?

zerollzeng commented 1 year ago

Check the TRT log to see if there are warnings about FP16 underflow/overflow.
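FP16's largest finite value is 65504 and its smallest positive normal value is 2^-14 (about 6.1e-5). Besides reading the TRT log, one rough offline check is to dump some intermediate activations (for example from the PyTorch model) and count how many fall outside that range. The sketch below is plain Python and does not use any TensorRT API; the function name and the input list are hypothetical.

```python
from typing import Dict, Iterable

# IEEE 754 binary16 (FP16) limits.
FP16_MAX = 65504.0            # largest finite FP16 value
FP16_MIN_NORMAL = 2.0 ** -14  # smallest positive normal FP16 value (~6.1e-5)


def fp16_range_report(values: Iterable[float]) -> Dict[str, int]:
    """Count values beyond the FP16 finite maximum or below the smallest normal magnitude.

    Values below the normal range lose precision as subnormals or round to zero.
    `values` is a hypothetical flat list of activations dumped from the original model.
    """
    values = list(values)
    too_large = sum(1 for v in values if abs(v) > FP16_MAX)
    too_small = sum(1 for v in values if v != 0.0 and abs(v) < FP16_MIN_NORMAL)
    return {"too_large": too_large, "too_small": too_small, "total": len(values)}


if __name__ == "__main__":
    print(fp16_range_report([1.0, 7e4, -1e5, 1e-6, 0.0]))
    # {'too_large': 2, 'too_small': 1, 'total': 5}
```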

You can fall back some layers to FP32 by using techniques such as bisection or per-layer diff analysis.
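Bisection can be sketched as a binary search over how many layers (in network order) are pinned to FP32. This is a plain-Python sketch: `is_accurate` is a hypothetical callback that would build an engine with the given layers forced to FP32 and compare its outputs against a reference, and the search assumes that forcing more layers to FP32 never makes accuracy worse. It finds the shortest passing prefix, which is not necessarily the smallest possible set of layers.

```python
from typing import Callable, List, Sequence


def bisect_fp32_prefix(layers: Sequence[str],
                       is_accurate: Callable[[List[str]], bool]) -> List[str]:
    """Shortest prefix of `layers` that must run in FP32 for `is_accurate` to pass.

    Assumes monotonicity: pinning more layers to FP32 never makes accuracy worse.
    """
    if not is_accurate(list(layers)):
        raise ValueError("accuracy check fails even with every layer in FP32")
    lo, hi = 0, len(layers)  # invariant: the prefix of length `hi` passes
    while lo < hi:
        mid = (lo + hi) // 2
        if is_accurate(list(layers[:mid])):
            hi = mid      # this shorter prefix is already enough
        else:
            lo = mid + 1  # more layers must run in FP32
    return list(layers[:hi])


if __name__ == "__main__":
    layers = ["conv1", "conv2", "conv3", "conv4"]
    # Toy oracle (hypothetical): accuracy is fine once conv3 runs in FP32.
    print(bisect_fp32_prefix(layers, lambda fp32: "conv3" in fp32))
    # ['conv1', 'conv2', 'conv3']
```

Each call to `is_accurate` is an engine build plus an evaluation, so the binary search needs only about log2(number of layers) builds instead of one per layer.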

konev-artem commented 1 year ago

@zerollzeng Hi, I've been debugging a similar issue (obvious TRT-FP16 quality degradation for a diffusion model, in my case), and I currently firmly believe that the problem is the FP16 accumulator type in the convolution kernels, which leads to errors that accumulate with network depth.

The reasons to believe are:

  1. The Torch-FP16 model works fine, and it uses kernels that look like, for example, "sm70_xmma_fprop_implicit_gemm_indexed_f16f16_f16f32_f32_nhwckrsc_nhwc_tilesize64x256x64_stage1_warpsize1x4x2_g1_tensor8x8x4_kernel_cudnn", while TRT-FP16 uses kernels with the "f16f16_f16f16_f16" substring. I suspect the difference stands for the accumulator type.

  2. If I limit the TensorRT builder to use only cuDNN kernels via the IAlgorithmSelector interface, then the TRT-FP16 engine works fine (though slower).

  3. I validated this on a synthetic one-layer network that doesn't have any FP16 out-of-range issues.
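The accumulator hypothesis in point 1 can be illustrated without TensorRT. The sketch below is an emulation in plain Python: it rounds inputs to FP16 using the half-precision "e" format of the `struct` module and compares a dot product whose running sum is rounded to FP16 after every step against one rounded to FP32. The reduction length 4608 (a 3x3 kernel times 512 input channels) is a hypothetical choice, and real GPU kernels use tiled, partly parallel reductions, so this shows the trend rather than the actual error of any TensorRT or cuDNN kernel.

```python
import random
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE binary16 value (struct raises OverflowError if out of range)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round to the nearest IEEE binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def dot_with_accumulator(a, b, round_acc):
    """Dot product of FP16-rounded inputs; the running sum is rounded by `round_acc` each step."""
    acc = 0.0
    for x, y in zip(a, b):
        prod = to_fp16(x) * to_fp16(y)  # exact in float64: two 11-bit significands
        acc = round_acc(acc + prod)
    return acc


def accumulation_errors(n: int = 4608, seed: int = 0):
    """Return (|error| with FP16 accumulator, |error| with FP32 accumulator) vs. a float64 sum."""
    rng = random.Random(seed)
    a = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    b = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    reference = sum(to_fp16(x) * to_fp16(y) for x, y in zip(a, b))
    err16 = abs(dot_with_accumulator(a, b, to_fp16) - reference)
    err32 = abs(dot_with_accumulator(a, b, to_fp32) - reference)
    return err16, err32


if __name__ == "__main__":
    e16, e32 = accumulation_errors()
    print(f"FP16 accumulator error: {e16:.3g}, FP32 accumulator error: {e32:.3g}")
```

The inputs are identical in both cases, so any gap between the two errors comes only from the precision of the running sum, which is the effect the commenter suspects compounds across layers.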

So, I wonder if there is any way to force TRT to use "f16f16_f16f32_f32" kernels like Torch does?
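Taking the commenter's reading of kernel names at face value (that the f32 vs f16 fields in tags like "f16f16_f16f32_f32" and "f16f16_f16f16_f16" reflect the accumulator type, which is a hypothesis and not documented behaviour), a small helper can group kernel names by that tag, for example names collected from a verbose build log or a profiler. How the names are collected is left out; the second example name below is made up.

```python
import re
from typing import Dict, List, Optional

# Matches type tags such as "f16f16_f16f32_f32" embedded between underscores in a kernel name.
TYPE_TAG = re.compile(r"(?:^|_)(f\d+f\d+_f\d+f\d+_f\d+)(?=_|$)")


def type_tag(kernel_name: str) -> Optional[str]:
    """Return the type tag of a kernel name, or None if it has none."""
    match = TYPE_TAG.search(kernel_name)
    return match.group(1) if match else None


def group_by_type_tag(kernel_names: List[str]) -> Dict[str, List[str]]:
    """Group kernel names by type tag; names without a tag go under "unknown"."""
    groups: Dict[str, List[str]] = {}
    for name in kernel_names:
        groups.setdefault(type_tag(name) or "unknown", []).append(name)
    return groups


if __name__ == "__main__":
    names = [
        "sm70_xmma_fprop_implicit_gemm_indexed_f16f16_f16f32_f32_nhwckrsc_nhwc"
        "_tilesize64x256x64_stage1_warpsize1x4x2_g1_tensor8x8x4_kernel_cudnn",
        "example_fprop_f16f16_f16f16_f16_kernel",  # hypothetical name
    ]
    for tag, members in group_by_type_tag(names).items():
        print(tag, len(members))
```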

zerollzeng commented 1 year ago

Some kernel implementations may be fast but prone to overflow.

> So, I wonder if there is any way to force TRT to use "f16f16_f16f32_f32" kernels like Torch does?

I don't know much about the IAlgorithmSelector. @nvpohanh @pranavm-nvidia, could you kindly help here? Thanks!

konev-artem commented 1 year ago

Hi @zerollzeng ,

Thank you for your reply.

Regarding your point about implementations that are fast but less accurate, I examined an isolated 3x3 convolution layer. Here are my results: https://pastebin.com/raw/xXrh85E7 (I print out the norm of the difference between each algorithm's output and a baseline FP32 implementation.)
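The metric described here, the norm of the difference between one algorithm's output and an FP32 baseline, can be written as a short helper. This is a plain-Python sketch; the optional relative form (dividing by the baseline's norm) is an addition that makes results comparable across layers with different output scales, and it is not necessarily what the linked results use.

```python
import math
from typing import Sequence


def diff_norm(candidate: Sequence[float], baseline: Sequence[float],
              relative: bool = False) -> float:
    """L2 norm of (candidate - baseline); optionally divided by the L2 norm of baseline."""
    if len(candidate) != len(baseline):
        raise ValueError("outputs must have the same number of elements")
    num = math.sqrt(sum((c - b) ** 2 for c, b in zip(candidate, baseline)))
    if not relative:
        return num
    den = math.sqrt(sum(b * b for b in baseline))
    if den == 0.0:
        return 0.0 if num == 0.0 else math.inf
    return num / den


if __name__ == "__main__":
    print(diff_norm([3.0, 4.0], [0.0, 0.0]))                    # 5.0
    print(diff_norm([1.1, 0.0], [1.0, 0.0], relative=True))     # ~0.1
```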

Then, I found that:

  1. Kernels #0-74 are FP32.
  2. Kernels #75-83 are the only FP16 kernels with decent accuracy. They have "cudnn" in their names, so I assume they come from cuDNN. However, they are probably slow, since TensorRT usually does not choose them and picks "implicit_gemm" kernels instead. In fact, their diff matches the one I got with Torch-FP16 (although Torch uses "implicit_gemm_f16f16_f16f32_f32" kernels).
  3. All kernels after #83 have an unacceptably large diff.

So the only option I currently have is to choose from this small and probably slow group of kernels. But that's only part of the problem: some layers in my network do not even have these cuDNN FP16 implementations, and for those layers I am forced to choose among FP32 kernels only.
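The per-layer rule described here (use an accurate FP16 kernel when one exists, otherwise fall back to FP32) can be written down as a small selection policy. This is a plain-Python sketch with a hypothetical `Tactic` record and made-up numbers; it is not TensorRT's IAlgorithmSelector API, although a real selector would have to apply the same kind of filter to the choices TensorRT offers for each layer.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Tactic:
    """Hypothetical record describing one candidate kernel for a layer."""
    tactic_id: int
    precision: str    # "fp16" or "fp32"
    time_ms: float    # measured kernel time
    diff_norm: float  # error vs. an FP32 baseline, measured offline


def pick_tactic(tactics: List[Tactic], max_diff: float) -> Optional[Tactic]:
    """Fastest FP16 tactic whose error is within `max_diff`, else the fastest FP32 tactic."""
    accurate_fp16 = [t for t in tactics if t.precision == "fp16" and t.diff_norm <= max_diff]
    pool = accurate_fp16 or [t for t in tactics if t.precision == "fp32"]
    return min(pool, key=lambda t: t.time_ms) if pool else None


if __name__ == "__main__":
    layer_tactics = [
        Tactic(1, "fp32", 2.0, 0.0),
        Tactic(2, "fp16", 1.5, 0.01),  # slower but accurate FP16 kernel
        Tactic(3, "fp16", 0.8, 0.5),   # fast but inaccurate FP16 kernel
    ]
    print(pick_tactic(layer_tactics, max_diff=0.05))  # picks tactic 2
```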

nvpohanh commented 1 year ago

TRT does not have APIs that let users specify the accumulator type yet. @zerollzeng Could you try to reproduce this and file an internal tracker so that we can track this feature request and potentially find some workarounds? Thanks

zerollzeng commented 1 year ago

@konev-artem Could you please provide a reproducer for this issue? We can investigate further, thanks!

konev-artem commented 1 year ago

@zerollzeng You can find the reproducer here: https://pastebin.com/raw/ZUJS8vqd It produces an outcome comparable to the one I've previously shared: https://pastebin.com/raw/na7bsUKD These results were obtained with a V100 GPU, torch-tensorrt==1.3.0, torch==1.13.1+cu116, tensorrt==8.5.1.7

zerollzeng commented 1 year ago

I see your reproducer contains only a conv layer. Could you please share the corresponding ONNX model here so I don't need to prepare the environment? Then I can do some experiments with TensorRT and ONNX Runtime, e.g. with Polygraphy:

polygraphy run model.onnx --trt --fp16 --onnxrt

Thanks!

konev-artem commented 1 year ago

Sure, here is the onnx file: https://github.com/konev-artem/temp/blob/06b1f3f4d0f8e866568ef9e69001857f87b5ba3f/model.onnx

zerollzeng commented 1 year ago

From what I can see, the diff is not very large, but FP16 does have a smaller range than FP32, so that is likely the reason for the final diff you see. How about just running the model in FP32?

[I]         Tolerance: [abs=1e-05, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-05/30/23-15:20:02: output | Stats: mean=-0.021959, std-dev=0.34085, var=0.11618, median=-0.021469, min=-1.4941 at (0, 249, 60, 48), max=1.4404 at (0, 122, 49, 41), avg-magnitude=0.27263
[I]             ---- Histogram ----
                Bin Range          |  Num Elems | Visualization
                (-1.49  , -1.2   ) |        226 |
                (-1.2   , -0.907 ) |       9304 |
                (-0.907 , -0.614 ) |      75587 | #####
                (-0.614 , -0.32  ) |     287245 | ###################
                (-0.32  , -0.0269) |     599313 | #######################################
                (-0.0269, 0.267  ) |     602034 | ########################################
                (0.267  , 0.56   ) |     309292 | ####################
                (0.56   , 0.854  ) |      76736 | #####
                (0.854  , 1.15   ) |       8055 |
                (1.15   , 1.44   ) |        336 |
[I]         onnxrt-runner-N0-05/30/23-15:20:02: output | Stats: mean=-0.021961, std-dev=0.34084, var=0.11617, median=-0.021475, min=-1.4875 at (0, 249, 60, 48), max=1.4387 at (0, 122, 49, 41), avg-magnitude=0.27263
[I]             ---- Histogram ----
                Bin Range          |  Num Elems | Visualization
                (-1.49  , -1.2   ) |        231 |
                (-1.2   , -0.907 ) |       9309 |
                (-0.907 , -0.614 ) |      75694 | #####
                (-0.614 , -0.32  ) |     287336 | ###################
                (-0.32  , -0.0269) |     599159 | #######################################
                (-0.0269, 0.267  ) |     602138 | ########################################
                (0.267  , 0.56   ) |     309271 | ####################
                (0.56   , 0.854  ) |      76613 | #####
                (0.854  , 1.15   ) |       8041 |
                (1.15   , 1.44   ) |        336 |
[I]         Error Metrics: output
[I]             Minimum Required Tolerance: elemwise error | [abs=0.014879] OR [rel=4932.8] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0.00085387, std-dev=0.00086303, var=7.4482e-07, median=0.00059594, min=0 at (0, 73, 10, 44), max=0.014879 at (0, 180, 12, 38), avg-magnitude=0.00085387
[I]                 ---- Histogram ----
                    Bin Range          |  Num Elems | Visualization
                    (0      , 0.00149) |    1643314 | ########################################
                    (0.00149, 0.00298) |     262423 | ######
                    (0.00298, 0.00446) |      48631 | #
                    (0.00446, 0.00595) |      10674 |
                    (0.00595, 0.00744) |       2459 |
                    (0.00744, 0.00893) |        473 |
                    (0.00893, 0.0104 ) |        123 |
                    (0.0104 , 0.0119 ) |         24 |
                    (0.0119 , 0.0134 ) |          6 |
                    (0.0134 , 0.0149 ) |          1 |
[I]             Relative Difference | Stats: mean=0.022086, std-dev=4.7962, var=23.004, median=0.0029441, min=0 at (0, 73, 10, 44), max=4932.8 at (0, 133, 6, 22), avg-magnitude=0.022086
[I]                 ---- Histogram ----
                    Bin Range            |  Num Elems | Visualization
                    (0       , 493     ) |    1968120 | ########################################
                    (493     , 987     ) |          3 |
                    (987     , 1.48e+03) |          1 |
                    (1.48e+03, 1.97e+03) |          1 |
                    (1.97e+03, 2.47e+03) |          1 |
                    (2.47e+03, 2.96e+03) |          0 |
                    (2.96e+03, 3.45e+03) |          1 |
                    (3.45e+03, 3.95e+03) |          0 |
                    (3.95e+03, 4.44e+03) |          0 |
                    (4.44e+03, 4.93e+03) |          1 |
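The "Minimum Required Tolerance" line above can look alarming (rel=4932.8), but the absolute difference peaks at about 0.015 on outputs whose magnitude reaches about 1.5. A relative difference blows up wherever the reference value is close to zero. The sketch below illustrates this with two hand-picked elements; it assumes the relative difference is |out - ref| / |ref| and that an element passes if it meets either tolerance, which follows the wording of the log but is not Polygraphy's actual code.

```python
from typing import List, Sequence, Tuple


def elementwise_errors(out: Sequence[float],
                       ref: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Absolute and relative differences, assuming rel = |out - ref| / |ref|."""
    abs_diff = [abs(o - r) for o, r in zip(out, ref)]
    rel_diff = [d / abs(r) if r != 0 else (0.0 if d == 0 else float("inf"))
                for d, r in zip(abs_diff, ref)]
    return abs_diff, rel_diff


def passes(out: Sequence[float], ref: Sequence[float], atol: float, rtol: float) -> bool:
    """Assumed elementwise rule: every element must satisfy atol OR rtol."""
    abs_diff, rel_diff = elementwise_errors(out, ref)
    return all(a <= atol or r <= rtol for a, r in zip(abs_diff, rel_diff))


if __name__ == "__main__":
    ref = [1.0, 0.001]    # the second reference value is close to zero
    out = [1.01, 0.011]   # both elements are off by the same ~0.01
    abs_d, rel_d = elementwise_errors(out, ref)
    print(max(abs_d), max(rel_d))  # ~0.01 absolute, but ~10 relative
```

With the same absolute error everywhere, the relative difference is 1% for the first element and 1000% for the second, which is why an absolute tolerance is usually the more meaningful check for outputs centered around zero.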
zerollzeng commented 1 year ago

Another solution is to fall back some of the layers to FP32. This requires some trial and error until you get the needed result.
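One way to organize that trial and error, sketched in plain Python: rank layers by a sensitivity score (for example, how much the output error drops when only that layer runs in FP32), then move layers to FP32 in that order until the accuracy check passes. Both the `sensitivity` scores and the `is_accurate` callback are hypothetical; in practice each check means rebuilding and re-evaluating an engine.

```python
from typing import Callable, Dict, List


def greedy_fp32_fallback(sensitivity: Dict[str, float],
                         is_accurate: Callable[[List[str]], bool]) -> List[str]:
    """Move the most sensitive layers to FP32, one at a time, until `is_accurate` passes."""
    fp32: List[str] = []
    ranked = sorted(sensitivity, key=sensitivity.get, reverse=True)
    for layer in ranked:
        if is_accurate(fp32):
            return fp32
        fp32.append(layer)
    if is_accurate(fp32):
        return fp32
    raise ValueError("accuracy check fails even with every ranked layer in FP32")


if __name__ == "__main__":
    scores = {"conv1": 0.1, "conv2": 0.5, "conv3": 0.3}  # made-up sensitivity scores
    # Toy oracle (hypothetical): accurate once conv2 and conv3 both run in FP32.
    print(greedy_fp32_fallback(scores, lambda fp32: {"conv2", "conv3"} <= set(fp32)))
    # ['conv2', 'conv3']
```

If the error is spread evenly across layers, the scores will be nearly flat and this search will end up moving many layers to FP32.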

konev-artem commented 1 year ago

Running the model in FP32 is an option but it would be less efficient given the intensive use of this model in production. I've tried to switch to FP32 for various parts of the model; however, it seems like the fp16 error is evenly distributed across conv layers of the model without any specific weak points.

Since I observed a significantly smaller error with PyTorch-FP16, I concluded that a potential solution would be to support the same "implicit_gemm_f16f16_f16f32_f32" kernels, which I believe differ from TensorRT's kernels in accumulator type.