NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

INT8 engine build failure of TensorRT 8.6.1 on RTX 3090 #3020

Closed sangxia closed 1 year ago

sangxia commented 1 year ago

Description

I am trying to build an engine from ONNX. It is a legacy model that has worked for me in the past (e.g. on TensorRT 8.0). I've managed to isolate the subgraph that I believe is causing the problem: subgraph.zip

When I run the following in the NGC PyTorch container 23.04:

polygraphy convert subgraph.onnx --convert-to trt --int8 --output tmp.trt

it ends with the following error message:

[E] 2: Assertion getter(i) != 0 failed.
[E] 2: [weightConvertors.cpp::quantizeBiasCommon::310] Error Code 2: Internal Error (Assertion getter(i) != 0 failed. )
[!] Invalid Engine. Please ensure the engine was built correctly

I suspect part of the issue is that the weights are all very small. However, this is a subgraph from a legacy model, so there is not much I can do about that. Until recently I had been using the NGC TRT container 21.08, which ships TensorRT 8.0, and the build works there. It also worked on JetPack 5.0.2 (TRT 8.4.1). Do you have any suggestions for workarounds? Thank you.

Environment

NGC PyTorch container 23.04

zerollzeng commented 1 year ago

I couldn't reproduce the issue with trtexec or polygraphy on my side:

$ polygraphy convert subgraph.onnx --convert-to trt --int8 --output tmp.trt
[W] Int8 Calibration is using randomly generated input data.
    This could negatively impact accuracy if the inference-time input data is dissimilar to the randomly generated calibration data.
    You may want to consider providing real data via the --data-loader-script option.
[W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[I]     Configuring with profiles: [Profile().add('/stage4/stage4.2/branches.0/branches.0.0/relu_1/Relu_output_0', min=[1, 32, 64, 48], opt=[1, 32, 64, 48], max=[1, 32, 64, 48])]
[I] Building engine with configuration:
    Flags                  | [INT8]
    Engine Capability      | EngineCapability.DEFAULT
    Memory Pools           | [WORKSPACE: 22539.12 MiB, TACTIC_DRAM: 22539.12 MiB]
    Tactic Sources         | [CUBLAS, CUBLAS_LT, CUDNN, EDGE_MASK_CONVOLUTIONS, JIT_CONVOLUTIONS]
    Profiling Verbosity    | ProfilingVerbosity.DETAILED
    Preview Features       | [FASTER_DYNAMIC_SHAPES_0805, DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
    Calibrator             | Calibrator(DataLoader(seed=1, iterations=1, int_range=(1, 25), float_range=(-1.0, 1.0), val_range=(0.0, 1.0)), BaseClass=<class 'tensorrt.tensorrt.IInt8EntropyCalibrator2'>)
[I] Finished engine building in 2.186 seconds

I can see the ONNX is very simple, so I think it's unlikely to trigger the error. Could you please double-check? Many thanks!

sangxia commented 1 year ago

Hi @zerollzeng

Thanks for checking. I tried again and got the same error; please see the full log below. I have also verified that the ONNX file I uploaded is the same one.

$ docker run -it --gpus all --ipc host --shm-size 32G -v $(pwd):/workspace nvcr.io/nvidia/pytorch:23.04-py3

=============
== PyTorch ==
=============

NVIDIA Release 23.04 (build 58180998)
PyTorch Version 2.1.0a0+fe05266

Container image Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Copyright (c) 2014-2023 Facebook Inc.
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU                      (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006      Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
Copyright (c) 2015      Google Inc.
Copyright (c) 2015      Yangqing Jia
Copyright (c) 2013-2016 The Caffe contributors
All rights reserved.

Various files include modifications (c) NVIDIA CORPORATION & AFFILIATES.  All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

WARNING: CUDA Minor Version Compatibility mode ENABLED.
  Using driver version 525.105.17 which has support for CUDA 12.0.  This container
  was built with CUDA 12.1 and will be run in Minor Version Compatibility mode.
  CUDA Forward Compatibility is preferred over Minor Version Compatibility for use
  with this container but was unavailable:
  [[Forward compatibility was attempted on non supported HW (CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE) cuInit()=804]]
  See https://docs.nvidia.com/deploy/cuda-compatibility/ for details.

root@a18116daff6f:/workspace# polygraphy convert subgraph.onnx --convert-to trt --int8 --output tmp.trt
[W] 'colored' module is not installed, will not use colors when logging. To enable colors, please install the 'colored' module: python3 -m pip install colored
[W] Int8 Calibration is using randomly generated input data.
    This could negatively impact accuracy if the inference-time input data is dissimilar to the randomly generated calibration data.
    You may want to consider providing real data via the --data-loader-script option.
[I]     Configuring with profiles: [Profile().add('/stage4/stage4.2/branches.0/branches.0.0/relu_1/Relu_output_0', min=[1, 32, 64, 48], opt=[1, 32, 64, 48], max=[1, 32, 64, 48])]
[I] Building engine with configuration:
    Flags                  | [INT8]
    Engine Capability      | EngineCapability.DEFAULT
    Memory Pools           | [WORKSPACE: 24237.50 MiB, TACTIC_DRAM: 24237.50 MiB]
    Tactic Sources         | [CUBLAS, CUBLAS_LT, CUDNN, EDGE_MASK_CONVOLUTIONS, JIT_CONVOLUTIONS]
    Profiling Verbosity    | ProfilingVerbosity.DETAILED
    Preview Features       | [FASTER_DYNAMIC_SHAPES_0805, DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
    Calibrator             | Calibrator(DataLoader(seed=1, iterations=1, int_range=(1, 25), float_range=(-1.0, 1.0), val_range=(0.0, 1.0)), BaseClass=<class 'tensorrt.tensorrt.IInt8EntropyCalibrator2'>)
[E] 2: Assertion getter(i) != 0 failed.
[E] 2: [weightConvertors.cpp::quantizeBiasCommon::310] Error Code 2: Internal Error (Assertion getter(i) != 0 failed. )
[!] Invalid Engine. Please ensure the engine was built correctly
root@a18116daff6f:/workspace# md5sum subgraph.onnx
eb7b40870daefae4fa42f5e786f0fee8  subgraph.onnx
zerollzeng commented 1 year ago

I can reproduce the issue locally. I will check the issue with our latest code and come back to you.

sangxia commented 1 year ago

I added the --calibration-cache argument to have a look at the calibration cache. I got the following:

TRT-8601-EntropyCalibration2
/stage4/stage4.2/branches.0/branches.0.0/relu_1/Relu_output_0: 3c0109c2
/stage4/stage4.2/branches.0/branches.0.1/conv1/Conv_output_0: 34166e31
/stage4/stage4.2/branches.0/branches.0.1/relu/Relu_output_0: 0
/stage4/stage4.2/branches.0/branches.0.1/conv2/Conv_output_0: 3ab90255

If I calibrate with this cache, it fails as above. However, if I manually change the line containing "0" and use the following:

TRT-8601-EntropyCalibration2
/stage4/stage4.2/branches.0/branches.0.0/relu_1/Relu_output_0: 3c0109c2
/stage4/stage4.2/branches.0/branches.0.1/conv1/Conv_output_0: 34166e31
/stage4/stage4.2/branches.0/branches.0.1/relu/Relu_output_0: 3c0109c2
/stage4/stage4.2/branches.0/branches.0.1/conv2/Conv_output_0: 3ab90255

Then it succeeds.
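
For reference, here is a minimal sketch of how that patching could be scripted (patch_cache.py is just a hypothetical helper, and the fallback scale is the neighbouring ReLU's value copied by hand, not something TensorRT computed):

# patch_cache.py -- hypothetical helper, not part of TensorRT or polygraphy.
# Replaces zero scales in a calibration cache with a fallback value copied
# from a neighbouring tensor (3c0109c2 ~ 0.0079, the ReLU scale above).
FALLBACK = "3c0109c2"

with open("tmp.cache") as f:
    lines = f.read().splitlines()

patched = []
for line in lines:
    name, sep, scale = line.rpartition(": ")
    if sep and scale.strip() == "0":
        line = name + ": " + FALLBACK
    patched.append(line)

with open("tmp_patched.cache", "w") as f:
    f.write("\n".join(patched) + "\n")

Re-running the convert command with --calibration-cache pointing at the patched file then reuses the edited scales, as in the workaround above.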

zerollzeng commented 1 year ago

/stage4/stage4.2/branches.0/branches.0.1/relu/Relu_output_0: 0

It's weird that the scale is 0. I checked with our latest internal code and it passes. I think we need to investigate this further.

ttyio commented 1 year ago

@zerollzeng , could you repro? thanks

zerollzeng commented 1 year ago

I can reproduce the 0 scale:

$ polygraphy convert subgraph.onnx --convert-to trt --int8 --output tmp.trt --calibration-cache tmp.cache
[W] Int8 Calibration is using randomly generated input data.
    This could negatively impact accuracy if the inference-time input data is dissimilar to the randomly generated calibration data.
    You may want to consider providing real data via the --data-loader-script option.
[W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[I]     Configuring with profiles: [Profile().add('/stage4/stage4.2/branches.0/branches.0.0/relu_1/Relu_output_0', min=[1, 32, 64, 48], opt=[1, 32, 64, 48], max=[1, 32, 64, 48])]
[I] Building engine with configuration:
    Flags                  | [INT8]
    Engine Capability      | EngineCapability.DEFAULT
    Memory Pools           | [WORKSPACE: 22517.44 MiB, TACTIC_DRAM: 22517.44 MiB]
    Tactic Sources         | [CUBLAS, CUBLAS_LT, CUDNN, EDGE_MASK_CONVOLUTIONS, JIT_CONVOLUTIONS]
    Profiling Verbosity    | ProfilingVerbosity.DETAILED
    Preview Features       | [FASTER_DYNAMIC_SHAPES_0805, DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
    Calibrator             | Calibrator(DataLoader(seed=1, iterations=1, int_range=(1, 25), float_range=(-1.0, 1.0), val_range=(0.0, 1.0)), cache='tmp.cache', BaseClass=<class 'tensorrt.tensorrt.IInt8EntropyCalibrator2'>)
[I] Saving calibration cache to tmp.cache
[I] Finished engine building in 2.890 seconds
$ ls
subgraph.onnx  subgraph.zip  tmp.cache  tmp.trt
$ cat tmp.cache 
TRT-8601-EntropyCalibration2
/stage4/stage4.2/branches.0/branches.0.0/relu_1/Relu_output_0: 3c0109c2
/stage4/stage4.2/branches.0/branches.0.1/conv1/Conv_output_0: 34166e31
/stage4/stage4.2/branches.0/branches.0.1/relu/Relu_output_0: 0
/stage4/stage4.2/branches.0/branches.0.1/conv2/Conv_output_0: 3ab90255
zerollzeng commented 1 year ago

@ttyio Could this be caused by the random input data? This network is pretty simple, so I would be surprised if there is a bug. (image of the network attached)

ttyio commented 1 year ago

I checked the ONNX: the conv weights are too small, and polygraphy uses random input drawn from the range (0, 1) by default, so calibration produces a scale of 0. @sangxia Is this a portion of a real network? Please use data captured from real inputs for calibration, thanks!
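
For reference, a minimal data-loader sketch for polygraphy's --data-loader-script option; the input name and shape come from the profile in the log above, and the random tensors are just placeholders to be replaced by real captured activations:

# data_loader.py -- sketch only; substitute real captured samples for the random data.
import numpy as np

INPUT_NAME = "/stage4/stage4.2/branches.0/branches.0.0/relu_1/Relu_output_0"

def load_data():
    # polygraphy iterates over this generator; each yielded dict is one calibration batch.
    for _ in range(32):
        yield {INPUT_NAME: np.random.rand(1, 32, 64, 48).astype(np.float32)}

and then something like: polygraphy convert subgraph.onnx --convert-to trt --int8 --data-loader-script data_loader.py --output tmp.trt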

sangxia commented 1 year ago

Hi @ttyio, it's isolated from a much larger network. I noticed the issue when running calibration for the full network with real data. I agree that the weights are very small, but I don't have much control over the training process. It is thus likely that the outputs of some of the layers are essentially all zero over the calibration data, so it's kind of a degenerate case, and I wonder whether this triggered some issue in TensorRT.

Also note that I didn't have this problem with the same network in earlier versions: it worked for me on TRT 8.0 and 8.4, but not on 8.5 or 8.6.

ttyio commented 1 year ago

@sangxia could you try adding more calibration data? Once TRT observes one calibration batch that produces non-zero output, the calibration scale is updated to a non-zero value. The WAR might be to copy the calibration scales from a legacy TRT calibration cache. It's unclean, but it works.

Also, could you try turning off TF32 by exporting NVIDIA_TF32_OVERRIDE=0 before calibration, if your GPU supports TF32?

sangxia commented 1 year ago

@ttyio When I tried to calibrate the full model (not the subgraph here), I used 512 samples, but it didn't help. I'm not sure how to control the amount of calibration data with polygraphy; I tried the --iters argument as below, but it didn't help either.

polygraphy convert subgraph.onnx --convert-to trt --int8 --output tmp.trt --calibration-cache tmp.cache --iters 256

I looked more into the weights of the first conv layer. The W values are all very small and the B values are all negative, so the outputs of this layer are likely all negative, which makes the output of the following Relu layer all 0 for any input of reasonable magnitude. This also matches the entry in the calibration cache above:

/stage4/stage4.2/branches.0/branches.0.1/relu/Relu_output_0: 0

Since the network has worked in the past and retraining is not an option, I am just looking for a way to get around the issue. As I can get this subgraph to work by manually changing the calibration cache (https://github.com/NVIDIA/TensorRT/issues/3020#issuecomment-1568647965), I wonder if that could be the way. Could you explain the format of the calibration cache, in particular how to interpret values like 3c0109c2?

ttyio commented 1 year ago

@sangxia, yes, manually changing this works. The 3c0109c2 is just the hex encoding of the fp32 number (0.00787586).
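
For example, a quick sketch of the round trip in Python, assuming the hex digits are the big-endian IEEE-754 bit pattern of the fp32 scale (which matches the value above):

import struct

def decode_scale(hex_str):
    # Interpret the hex digits as the raw IEEE-754 bits of a float32.
    return struct.unpack(">f", bytes.fromhex(hex_str))[0]

def encode_scale(value):
    # Inverse: produce the hex string for a chosen float32 scale.
    return struct.pack(">f", value).hex()

print(decode_scale("3c0109c2"))                # ~0.00787586
print(encode_scale(decode_scale("3c0109c2")))  # round-trips to "3c0109c2"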

For your polygraphy command line, I didn't see how you provide the input calibration data. Did you use --load-inputs?

sangxia commented 1 year ago

For your polygraphy command line, I didn't see how you provide the input calibration data. Did you use --load-inputs?

I had a data feeder when quantizing the whole network, but I didn't implement a separate one for this subgraph; I assumed polygraphy would generate random inputs in that case. In any case, I think I'll just manually modify the calibration cache to get around the problem, so I'll close the issue now. Thanks for the explanations.

p3achyjr commented 1 year ago

FWIW, I am also running into this issue. Here is a partial paste of my calibration cache: https://pastebin.com/7252gvGy

How can I interpret the contents of this cache? Why is it an issue if the outputs are 0?

sangxia commented 1 year ago

My interpretation is that the values in the cache are the scales, as explained here: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#working-with-int8. If a scale is zero, then there could be a division-by-zero problem, depending on the implementation.
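
For illustration, a rough sketch of symmetric INT8 quantization (the standard formula from the docs, not TensorRT's internal code) that shows why a zero scale is a problem:

import numpy as np

def quantize_int8(x, scale):
    # Symmetric INT8 quantization: q = clamp(round(x / scale), -127, 127).
    # A calibration scale of 0 turns this into a division by zero, which is
    # presumably what the quantizeBiasCommon assertion is guarding against.
    if scale == 0.0:
        raise ZeroDivisionError("calibration scale is 0")
    return np.clip(np.round(x / scale), -127, 127).astype(np.int8)

print(quantize_int8(np.array([0.004, -0.002, 1.0]), 0.00787586))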