NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

Failed INT8 quantization. #1847

Open deephog opened 2 years ago

deephog commented 2 years ago

Dear Developers,

I am very new to TensorRT and quantization. Previously I only used the basic TensorRT example to generate FP16 engines, because I thought INT8 would compromise accuracy significantly. Only recently did I realize there are methods to improve the accuracy of a quantized model, so I started learning how to do this in TensorRT.

I can successfully convert my segmentation model with a ResNet34 backbone to both FP16 and INT8 without any issues. But when I switch the backbone to efficientnet_b2, the model still converts to FP16, yet the INT8 build fails with the error message below. I tried increasing the workspace size, as some other threads suggested, but it doesn't help.

ONNX files of the models mentioned: EfficientNet-backbone model (ONNX), ResNet34-backbone model (ONNX)

  [TensorRT] VERBOSE: --------------- Timing Runner: Conv_22 + PWN(PWN(Sigmoid_23), Mul_24) (CudaDepthwiseConvolution)
  [TensorRT] VERBOSE: CudaDepthwiseConvolution has no valid tactics for this config, skipping
  [TensorRT] VERBOSE: --------------- Timing Runner: Conv_22 + PWN(PWN(Sigmoid_23), Mul_24) (CudnnConvolution)
  [TensorRT] VERBOSE: CudnnConvolution has no valid tactics for this config, skipping
  [TensorRT] VERBOSE: --------------- Timing Runner: Conv_22 + PWN(PWN(Sigmoid_23), Mul_24) (CaskConvolution)
  [TensorRT] VERBOSE: CaskConvolution has no valid tactics for this config, skipping
  [TensorRT] VERBOSE: *************** Autotuning format combination: Half(65536,1:16,256,1) -> Half(65536,1:16,256,1) ***************
  [TensorRT] VERBOSE: --------------- Timing Runner: Conv_22 + PWN(PWN(Sigmoid_23), Mul_24) (CudnnConvolution)
  [TensorRT] VERBOSE: CudnnConvolution has no valid tactics for this config, skipping
  [TensorRT] VERBOSE: --------------- Timing Runner: Conv_22 + PWN(PWN(Sigmoid_23), Mul_24) (CaskConvolution)
  [TensorRT] VERBOSE: CaskConvolution has no valid tactics for this config, skipping
  [TensorRT] VERBOSE: *************** Autotuning format combination: Half(65536,65536:32,256,1) -> Half(65536,65536:32,256,1) ***************
  [TensorRT] VERBOSE: --------------- Timing Runner: Conv_22 + PWN(PWN(Sigmoid_23), Mul_24) (CudnnConvolution)
  [TensorRT] VERBOSE: CudnnConvolution has no valid tactics for this config, skipping
  [TensorRT] VERBOSE: --------------- Timing Runner: Conv_22 + PWN(PWN(Sigmoid_23), Mul_24) (CaskConvolution)
  [TensorRT] VERBOSE: CaskConvolution has no valid tactics for this config, skipping
  [TensorRT] VERBOSE: Deleting timing cache: 881 entries, 1138 hits
  [TensorRT] INFO: [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 4763, GPU 9038 (MiB)
  [TensorRT] ERROR: 10: [optimizer.cpp::computeCosts::1855] Error Code 10: Internal Error (Could not find any implementation for node Conv_22 + PWN(PWN(Sigmoid_23), Mul_24).)
  [TensorRT] ERROR: 2: [builder.cpp::buildSerializedNetwork::417] Error Code 2: Internal Error (Assertion enginePtr != nullptr failed.)
  Completed creating Engine
  Traceback (most recent call last):
    File "basnet_test.py", line 231, in <module>
      main(batch_size=bs, image_size=(256, 256), proc_size=(256, 256), input_size=(720, 1280), model_name='effnet_b2'
    File "basnet_test.py", line 175, in main
      f.write(serialized_engine)
  TypeError: a bytes-like object is required, not 'NoneType'
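The final `TypeError` is only a downstream symptom. In the TensorRT Python API, `builder.build_serialized_network(network, config)` returns `None` when the build fails, and the script then passes that `None` to `f.write`. Below is a minimal sketch of a guard that turns this into an explicit error pointing back at the builder log. The helper name `save_engine` is hypothetical and does not come from the original script.

```python
from typing import Any, Optional


def save_engine(serialized_engine: Optional[Any], path: str) -> int:
    """Write a serialized engine (e.g. the result of build_serialized_network) to disk.

    TensorRT returns None when the build fails; raise a clear error instead of
    letting f.write() fail later with a confusing TypeError.
    Returns the number of bytes written.
    """
    if serialized_engine is None:
        raise RuntimeError(
            "Engine build failed: build_serialized_network() returned None; "
            "check the builder log for the first ERROR line"
        )
    with open(path, "wb") as f:
        # Works for bytes and for buffer-like objects such as TensorRT's IHostMemory.
        return f.write(serialized_engine)
```

With this guard in place, the failure in the log above would surface as a `RuntimeError` raised at the save step, and the "Could not find any implementation for node ..." line would be the place to look.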

The code I used to generate the engine is almost the same as the official example:

https://github.com/deephog/code/blob/61f3d9f445296ebd7fe2593b7f6ba034017471cc/question#L1-L26

I also tried enabling the FP16 flag together with INT8 to give the builder more flexibility, but the error persisted.

Please share any of your thoughts or suggestions. Much appreciated!

ttyio commented 2 years ago

@deephog, which TRT version and GPU are you using? Have you tried the latest 8.4 EA? Thanks.

deephog commented 2 years ago

I downloaded a Docker image from NGC two or three months ago; it uses TRT 8.0.3.4, and I'm building on an RTX 3090. Do you think upgrading TRT might fix this?

deephog commented 2 years ago

I updated TensorRT to 8.4 EA and that solved the problem! The engine now builds successfully.

However, the compiled engine runs absurdly slowly.

If I build the model in FP16 alone, it runs at about 11 ms per batch, and even in FP32 it still runs at 24 ms per batch. But when I turn on the INT8 flag (even with FP16 enabled as well), the engine takes more than 1 second per batch!
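To put these latencies side by side, here is a quick calculation that uses only the numbers quoted above. It treats "more than 1 second" as exactly 1000 ms, so the INT8 ratios below are lower bounds.

```python
# Per-batch latencies reported above, in milliseconds.
# The FP16+INT8 figure was reported as "more than 1 second"; 1000 ms is a lower bound.
latency_ms = {"FP32": 24.0, "FP16": 11.0, "FP16+INT8": 1000.0}


def slowdown(slow: str, fast: str) -> float:
    """How many times slower configuration `slow` is than configuration `fast`."""
    return latency_ms[slow] / latency_ms[fast]


for ref in ("FP32", "FP16"):
    print(f"FP16+INT8 is at least {slowdown('FP16+INT8', ref):.0f}x slower than {ref}")
```

The ratios come out to roughly 42x versus FP32 and roughly 91x versus FP16. That magnitude is hard to explain with a simple precision fallback.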

I've uploaded all three engines here: FP16, FP32, FP16&INT8

You can see that the FP32 engine is about twice the size of the FP16 engine, while the FP16/INT8 hybrid engine is only slightly larger than the FP16 engine. So I can't explain the slowdown as a fallback to FP32 (it couldn't be that anyway, since the hybrid engine is more than 40 times slower than the FP32 one). It almost feels like it is running on a CPU, but there's no way a TensorRT engine could mistakenly run on the CPU, right? It's CUDA-based.

Besides this issue, I have another question: is it necessary to do the quantization in TensorRT? Can I quantize the model in PyTorch and then compile it directly into an INT8 engine? The second way would be much easier for me, since most of the TensorRT quantization documentation is in C++, which I know almost nothing about. If it is possible, please give me some hints on how to do it.

Thank you!

ttyio commented 2 years ago

@deephog, how are you calibrating the model now when you build the INT8 engine? Are you using setDynamicRange or your own custom calibrator?
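For reference, the setDynamicRange route (`ITensor.set_dynamic_range(min, max)` in the TensorRT 8.x Python API) skips calibration entirely. You assign a symmetric range to each tensor yourself, then build with the INT8 flag and no calibrator. Below is a minimal sketch. It assumes the ranges are already available as a `{tensor_name: amax}` dict, for example from an external calibration or QAT run. The helper name `apply_dynamic_ranges` is hypothetical, and the code relies only on the `INetworkDefinition`, `ILayer`, and `ITensor` accessors shown.

```python
from typing import Dict, List


def apply_dynamic_ranges(network, amax_by_tensor: Dict[str, float]) -> List[str]:
    """Set a symmetric INT8 dynamic range [-amax, amax] on every named tensor.

    `network` is expected to behave like a TensorRT 8.x INetworkDefinition
    (num_inputs/get_input, num_layers/get_layer, ILayer.num_outputs/get_output).
    Returns the names of the tensors that received a range, so gaps are easy to spot.
    """
    # Collect network inputs plus every layer output.
    tensors = [network.get_input(i) for i in range(network.num_inputs)]
    for li in range(network.num_layers):
        layer = network.get_layer(li)
        tensors.extend(layer.get_output(oi) for oi in range(layer.num_outputs))

    applied = []
    for tensor in tensors:
        amax = amax_by_tensor.get(tensor.name)
        if amax is None:
            # No range known: TensorRT 8.x typically warns and uses non-INT8
            # kernels for layers that consume or produce this tensor.
            continue
        tensor.set_dynamic_range(-amax, amax)
        applied.append(tensor.name)
    return applied
```

In a real build script, `network` would come from the ONNX parser. After calling the helper, you would set `config.set_flag(trt.BuilderFlag.INT8)` and leave `config.int8_calibrator` unset.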

Regarding your question of whether you can quantize the model in PyTorch and then compile it directly into an INT8 engine:

We have a PyTorch-based QAT (quantization-aware training) tool: https://github.com/NVIDIA/TensorRT/tree/main/tools/pytorch-quantization

As a first step, could you run trtexec and attach the logs here? The command lines below assign random INT8 dynamic ranges, so we can get INT8 performance data without calibration. Thanks!

  trtexec --onnx=your_onnx_file --verbose --noDataTransfers --separateProfileRun --dumpProfile --useCudaGraph > run_fp32.log
  trtexec --onnx=your_onnx_file --fp16 --verbose --noDataTransfers --separateProfileRun --dumpProfile --useCudaGraph > run_fp16.log
  trtexec --onnx=your_onnx_file --int8 --fp16 --verbose --noDataTransfers --separateProfileRun --dumpProfile --useCudaGraph > run_fp16_int8.log

deephog commented 2 years ago

Thank you! I will try both of your suggestions and come back later with the results!

deephog commented 2 years ago

Here are the resulting logs: FP32, FP16, FP16&INT8

If I'm reading the logs correctly, the inference time does decrease when INT8 is enabled, which means the INT8 process itself is successful.
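To compare the three runs without reading the logs by eye, you can pull the mean GPU compute time out of each trtexec summary. Below is a minimal sketch. The exact wording of trtexec's performance summary differs between versions, so the regular expression is an assumption: it looks for a line containing `GPU Compute` followed by `mean = X ms` or `mean: X ms`, and you should adjust it to match your logs. The helper name `mean_gpu_compute_ms` is hypothetical.

```python
import re
from typing import Optional

# Assumed summary format, e.g. "... GPU Compute Time: min = 1.0 ms, max = 2.0 ms, mean = 1.2 ms, ..."
# Some trtexec versions may print "mean: 1.2 ms" instead; both spellings are accepted here.
_MEAN_RE = re.compile(r"GPU Compute.*?mean\s*[=:]\s*([0-9.]+)\s*ms")


def mean_gpu_compute_ms(log_text: str) -> Optional[float]:
    """Return the mean GPU compute time (ms) from the last matching summary line, or None."""
    matches = _MEAN_RE.findall(log_text)
    return float(matches[-1]) if matches else None
```

For example, you could call it on the contents of `run_fp32.log`, `run_fp16.log`, and `run_fp16_int8.log` and compare the three values. A `None` result means the pattern did not match that trtexec version's output.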

As for your other question: I adapted the code from this repo to create my own calibrator, because I had a really hard time finding calibration-related information for Python.
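A TensorRT Python calibrator subclasses `trt.IInt8EntropyCalibrator2` and implements `get_batch_size`, `get_batch`, `read_calibration_cache`, and `write_calibration_cache`. Below is a minimal sketch of the cache half only. It is written as a standalone class (the names `CalibrationCache` and `cache_file` are hypothetical) so it can run without TensorRT. In a real calibrator, these two methods would live in your `trt.IInt8EntropyCalibrator2` subclass. Reusing the cache lets later builds skip recalibration.

```python
import os
from typing import Optional


class CalibrationCache:
    """Cache handling for an INT8 calibrator (hypothetical standalone class).

    In a real calibrator these methods belong to a trt.IInt8EntropyCalibrator2
    subclass, alongside get_batch_size() and get_batch().
    """

    def __init__(self, cache_file: str = "calibration.cache"):
        self.cache_file = cache_file

    def read_calibration_cache(self) -> Optional[bytes]:
        # Returning None means "no cache": TensorRT then calibrates using get_batch().
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache) -> None:
        # TensorRT hands over the computed scales as a buffer after calibration.
        with open(self.cache_file, "wb") as f:
            f.write(cache)
```

Deleting the cache file forces a fresh calibration. That step is needed after changing the model or the calibration data.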

For the PyTorch quantization tool, is my understanding of the process correct? First I quantize the model with the tool you provided, which gives me a file defining the dynamic range of each layer. Then I run trtexec with that dynamic-range file to create the engine. If I'm wrong, please give me a hint about what I should do with the output of the PyTorch quantization tool. Thanks!
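A per-tensor range ("amax") is what both calibration and QAT ultimately produce. To show what such a range does to the values, here is a minimal sketch of symmetric per-tensor INT8 fake quantization (quantize, then immediately dequantize). It assumes a step size of amax / 127 and clamping to the integer codes [-127, 127]; both are assumptions of this sketch, following the common symmetric convention, and are not a statement of TensorRT's exact kernels. The function name `fake_quantize` is illustrative.

```python
from typing import List


def fake_quantize(values: List[float], amax: float) -> List[float]:
    """Symmetric per-tensor INT8 quantize-dequantize.

    step = amax / 127, so [-amax, amax] maps onto integer codes [-127, 127];
    values outside that range saturate at +/-amax.
    """
    if amax <= 0:
        raise ValueError("amax must be positive")
    step = amax / 127.0
    out = []
    for v in values:
        # Python's round() is round-half-to-even; GPU kernels may break ties differently.
        q = max(-127, min(127, round(v / step)))  # integer INT8 code
        out.append(q * step)                      # value the INT8 layer effectively sees
    return out
```

This illustrates the trade-off calibration has to make. A large amax gives coarse steps (rounding error up to step / 2 inside the range), while a small amax clips large activations.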

deephog commented 2 years ago

Hi Xiaodong, I went through the pytorch-quantization tutorial you pointed me to. It explains clearly how to do the quantization, and it covers exporting the ONNX model for TensorRT. However, it doesn't give an example of how to build an engine from the quantized ONNX model.

From what I read in the TensorRT guide, to build an engine from a pre-quantized model you need to provide a fixed range for each layer. I didn't see any option to read a quantized or partially quantized ONNX model directly and build it as-is, without providing extra information.
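For context, TensorRT 8.x also supports explicit quantization. In that mode the ONNX model itself carries `QuantizeLinear`/`DequantizeLinear` (Q/DQ) nodes with their scales. That is the form pytorch-quantization's ONNX export is meant to produce, and it needs no separate per-layer range file, only the INT8 builder flag (for example `trtexec --onnx=<model> --int8`). Below is a quick check that an export actually contains Q/DQ nodes. It is sketched to work on any object with an ONNX-style `graph.node` list, such as the result of `onnx.load(path)`. The helper name `count_qdq_nodes` is hypothetical.

```python
from collections import Counter
from typing import Dict

QDQ_OPS = ("QuantizeLinear", "DequantizeLinear")


def count_qdq_nodes(model) -> Dict[str, int]:
    """Count Q/DQ nodes in an ONNX-style model (anything with model.graph.node[i].op_type).

    Only the top-level graph is inspected; nodes inside subgraphs (If/Loop bodies) are not counted.
    """
    counts = Counter(node.op_type for node in model.graph.node)
    return {op: counts.get(op, 0) for op in QDQ_OPS}
```

For example, with the `onnx` package installed, call `count_qdq_nodes(onnx.load("model_qat.onnx"))`, where the file name is a placeholder. All-zero counts would suggest the export did not emit Q/DQ nodes.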

Could you please share a bit more information on how to do it?

Thanks!