microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Quantization] How to add QDQ pairs to the bias of conv and gemm operators? #14229

Closed: HYLcool closed this issue 1 year ago

HYLcool commented 1 year ago

### Describe the issue

Hi, I tried to use the QDQ format to quantize my ONNX model and used trtexec to benchmark its inference speed, and I ran into a problem similar to #11535. After adding `extra_options={'AddQDQPairToWeight': True}` to `quantize_static`, the quantized model still fails to run on TensorRT and returns errors like this:

```
[01/11/2023-12:07:32] [E] Error[3]: 476_DequantizeLinear: only activation types allowed as input to this layer.
[01/11/2023-12:07:32] [E] [TRT] ModelImporter.cpp:773: While parsing node number 1 [DequantizeLinear -> "476"]:
[01/11/2023-12:07:32] [E] [TRT] ModelImporter.cpp:774: --- Begin node ---
[01/11/2023-12:07:32] [E] [TRT] ModelImporter.cpp:775: input: "476_quantized"
input: "476_quantized_scale"
input: "476_quantized_zero_point"
output: "476"
name: "476_DequantizeLinear"
op_type: "DequantizeLinear"

[01/11/2023-12:07:32] [E] [TRT] ModelImporter.cpp:776: --- End node ---
[01/11/2023-12:07:32] [E] [TRT] ModelImporter.cpp:779: ERROR: ModelImporter.cpp:180 In function parseGraph:
[6] Invalid Node - 476_DequantizeLinear
476_DequantizeLinear: only activation types allowed as input to this layer.
[01/11/2023-12:07:32] [E] Failed to parse onnx file
[01/11/2023-12:07:32] [I] Finish parsing network model
[01/11/2023-12:07:32] [E] Parsing model failed
[01/11/2023-12:07:32] [E] Failed to create engine from model or file.
[01/11/2023-12:07:32] [E] Engine set up failed
```

I found that the FP32 weights are now followed by QDQ pairs, but the quantized bias still has only a DQ op after it (see the figure below). That may be why this error occurs.

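[Figure: the quantized graph, with QDQ pairs on the FP32 weights and only a DequantizeLinear node on the quantized bias]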
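The pattern described above (weights behind a Q→DQ pair, bias behind a lone DQ) can be checked programmatically. The sketch below uses a toy dict-based graph instead of the `onnx` protobuf API, an assumption made to keep it self-contained; the weight tensor names are hypothetical, while the bias names are taken from the trtexec log.

```python
# Toy graph representation (assumption, not the onnx API): each node is a dict
# with name, op_type, inputs and outputs; `initializers` maps constant tensor
# names to their dtypes.

def lone_dequantize_nodes(nodes, initializers):
    """Return the names of DequantizeLinear nodes whose data input is a
    constant initializer instead of the output of another node (e.g. a Q)."""
    return [
        n["name"]
        for n in nodes
        if n["op_type"] == "DequantizeLinear" and n["inputs"][0] in initializers
    ]

initializers = {
    "conv_w": "float32",              # hypothetical FP32 weight
    "conv_w_scale": "float32",
    "conv_w_zero_point": "int8",
    "476_quantized": "int32",         # quantized bias, as in the log
    "476_quantized_scale": "float32",
    "476_quantized_zero_point": "int32",
}
nodes = [
    # Weight path: FP32 initializer -> QuantizeLinear -> DequantizeLinear
    {"name": "conv_w_QuantizeLinear", "op_type": "QuantizeLinear",
     "inputs": ["conv_w", "conv_w_scale", "conv_w_zero_point"], "outputs": ["conv_w_q"]},
    {"name": "conv_w_DequantizeLinear", "op_type": "DequantizeLinear",
     "inputs": ["conv_w_q", "conv_w_scale", "conv_w_zero_point"], "outputs": ["conv_w_dq"]},
    # Bias path: int32 initializer -> DequantizeLinear only
    {"name": "476_DequantizeLinear", "op_type": "DequantizeLinear",
     "inputs": ["476_quantized", "476_quantized_scale", "476_quantized_zero_point"],
     "outputs": ["476"]},
]

print(lone_dequantize_nodes(nodes, initializers))  # ['476_DequantizeLinear']
```

With the real `onnx` package, the same check would iterate over `model.graph.node` and compare each `node.input[0]` against the names in `model.graph.initializer`.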

So I wonder: does quantization in ONNX Runtime support adding QDQ pairs to the bias, like `AddQDQPairToWeight` does for weights?
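For context on why the bias ends up behind a lone DQ node: in the common int8 scheme, which as far as I know is what ONNX Runtime's static quantizer uses for Conv/Gemm bias, the bias has no calibrated range of its own. Its scale is derived as `input_scale * weight_scale`, the values are stored as an int32 initializer with zero point 0, and only a DequantizeLinear is needed at run time. A minimal pure-Python sketch of that arithmetic, with made-up scales:

```python
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1

def quantize_bias(bias, input_scale, weight_scale):
    """Quantize a float bias to int32 with scale = input_scale * weight_scale
    and zero point 0 (the usual convention for int8 Conv/Gemm bias)."""
    bias_scale = input_scale * weight_scale
    q = [max(INT32_MIN, min(INT32_MAX, round(b / bias_scale))) for b in bias]
    return q, bias_scale, 0

def dequantize_linear(q, scale, zero_point):
    """DequantizeLinear semantics: y = (x - zero_point) * scale."""
    return [(v - zero_point) * scale for v in q]

bias = [0.5, -0.25, 1.0]
q, scale, zp = quantize_bias(bias, input_scale=0.02, weight_scale=0.01)
print(q)                               # [2500, -1250, 5000]
print(dequantize_linear(q, scale, zp)) # approximately [0.5, -0.25, 1.0]
```

Because the int32 values are baked into the initializer, there is no FP32 bias left for a QuantizeLinear to act on, which is why the graph contains a DQ node but no Q node for the bias.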

### To reproduce

The ONNX model mentioned above is the MobileNetV2 model from the ONNX Model Zoo.

Steps (similar to the example from here):

1. Pre-processing:

   ```shell
   python -m onnxruntime.quantization.preprocess --input mobilenetv2-7.onnx --output mobilenetv2-7-infer.onnx
   ```

2. Quantization:

   ```python
   from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

   quantize_static(
       'mobilenetv2-opset10-infer.onnx',
       'mobilenetv2-opset10-quantized.onnx',
       data_reader,  # from the run.py code
       quant_format=QuantFormat.QDQ,
       per_channel=False,
       weight_type=QuantType.QInt8,
       activation_type=QuantType.QInt8,
       optimize_model=False,
       extra_options={'AddQDQPairToWeight': True},
   )
   ```

3. Benchmark on TensorRT:

   ```shell
   trtexec --onnx=mobilenetv2-opset10-quantized.onnx --avgRuns=1000 --workspace=1024 --verbose --int8
   ```

trtexec then reports the errors shown above.

### Urgency

_No response_

### Platform

Linux

### OS Version

7.2

### ONNX Runtime Installation

Released Package

### ONNX Runtime Version or Commit ID

1.13.1

### ONNX Runtime API

Python

### Architecture

X64

### Execution Provider

TensorRT

### Execution Provider Library Version

CUDA 11.7 & TensorRT 8.4.1.5
xadupre commented 1 year ago

I wonder whether this issue explains the error message: https://github.com/NVIDIA/TensorRT/issues/2165. If so, the activation type would need to be different for TensorRT.

HYLcool commented 1 year ago

> I wonder if this issue would explain the error message: NVIDIA/TensorRT#2165. Then the activation type should be different for tensorrt.

Thanks for your reply! To be honest, I don't really understand the issue you mentioned, and I can't find any operator named Scale in the ONNX operator list. 😢

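I found that the inputs of the DQ operator for the bias are:

And the inputs of the DQ operators for the weights are:

Both sets of inputs are accepted by the DQ operator, according to its type constraints below:

[Figure: DequantizeLinear type constraints from the ONNX operator documentation]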

So I don't think this is about the dtype of the inputs, as in the issue you mentioned.
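The dtype argument can be checked against the ONNX spec. The screenshots of the actual inputs were not preserved, so the sketch below assumes that the bias DQ takes int32 data and zero point and the weight DQ takes int8 (matching `weight_type=QuantType.QInt8`), both with a float32 scale. The type sets are those of DequantizeLinear in opsets 10 and 13; later opsets add more types.

```python
# DequantizeLinear type constraints in ONNX opsets 10 and 13.
DQ_DATA_TYPES = {"int8", "uint8", "int32"}
DQ_SCALE_TYPES = {"float32"}

def dq_inputs_valid(x_dtype, scale_dtype, zero_point_dtype):
    """Check DequantizeLinear input dtypes against the opset 10/13 spec.
    The zero point must have the same type as the quantized input x."""
    return (
        x_dtype in DQ_DATA_TYPES
        and scale_dtype in DQ_SCALE_TYPES
        and zero_point_dtype == x_dtype
    )

# Assumed bias DQ inputs: int32 data, float32 scale, int32 zero point.
print(dq_inputs_valid("int32", "float32", "int32"))  # True
# Assumed weight DQ inputs: int8 data, float32 scale, int8 zero point.
print(dq_inputs_valid("int8", "float32", "int8"))    # True
```

Both combinations pass, so at the spec level the dtypes alone do not distinguish the two DQ nodes.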

Before I added `extra_options={'AddQDQPairToWeight': True}` to `quantize_static`, the error was raised at the DQ operator in front of the weights, as in #11535. So I think adding QDQ pairs to the bias might solve this problem as well.
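The idea of adding QDQ pairs to the bias mirrors what `AddQDQPairToWeight` does for weights: the tensor stays in FP32 and the graph applies QuantizeLinear followed by DequantizeLinear, which numerically is a "fake quantization" round trip. A minimal sketch of that round trip for int8; the symmetric zero point of 0 and the scale of 0.01 are assumptions for illustration:

```python
def quantize_linear(x, scale, zero_point=0, qmin=-128, qmax=127):
    """QuantizeLinear to int8: saturate(round(x / scale) + zero_point).
    Python's round() rounds halves to even, as the ONNX spec requires."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in x]

def dequantize_linear(q, scale, zero_point=0):
    """DequantizeLinear: (q - zero_point) * scale."""
    return [(v - zero_point) * scale for v in q]

def fake_quantize(x, scale, zero_point=0):
    """What a Q -> DQ pair on an FP32 tensor computes."""
    return dequantize_linear(quantize_linear(x, scale, zero_point), scale, zero_point)

weights = [0.1, -0.37, 0.004, 2.0]
print(quantize_linear(weights, 0.01))  # [10, -37, 0, 127]
print(fake_quantize(weights, 0.01))    # approximately [0.1, -0.37, 0.0, 1.27]
```

The last value (2.0) saturates to 127 and comes back as 1.27, so the result depends on a scale that covers the tensor's range.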

HYLcool commented 1 year ago

Hi, I tried writing out the calibration table and then running trtexec with the `--calib` argument, and it loaded the quantized engine successfully. So this looks like an alternative way to run a quantized ONNX model on TensorRT.
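A calibration table boils down to a dynamic range per tensor. Assuming symmetric int8 quantization (which is what TensorRT uses for int8), each calibrated `(min, max)` range reduces to `amax = max(|min|, |max|)` and a scale of `amax / 127`. ONNX Runtime's quantization package has a `write_calibration_table` helper for producing such tables; the sketch below only shows the underlying arithmetic, not any file format, and the tensor ranges are made up:

```python
def symmetric_int8_scale(rmin, rmax):
    """Per-tensor scale for symmetric int8 quantization of a calibrated range."""
    amax = max(abs(rmin), abs(rmax))
    return amax / 127.0

def calibration_table(ranges):
    """Map tensor name -> (amax, scale), given tensor name -> (min, max)."""
    table = {}
    for name, (rmin, rmax) in ranges.items():
        amax = max(abs(rmin), abs(rmax))
        table[name] = (amax, symmetric_int8_scale(rmin, rmax))
    return table

# Hypothetical calibrated ranges for two tensors.
ranges = {"input": (-2.54, 1.0), "476": (0.0, 6.35)}
for name, (amax, scale) in calibration_table(ranges).items():
    print(f"{name}: amax={amax}, scale={scale:.4f}")
```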

But I still wonder whether ONNX Runtime can generate a quantized ONNX model that TensorRT can consume directly. 🤔

HYLcool commented 1 year ago

It seems that PR #14549 solves this problem by removing the DQ nodes on the bias: setting `QuantizeBias` to False leaves the bias unquantized. Thanks, everyone!
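With that fix, the quantization call from the reproduction steps only needs one more entry in `extra_options`. This sketch assumes an ONNX Runtime release that includes PR #14549. `QuantizeBias` is the option named above, and the helper function is hypothetical:

```python
def tensorrt_qdq_extra_options():
    """Hypothetical helper: extra_options for quantize_static when the QDQ
    model is meant to be parsed by TensorRT."""
    return {
        # Put QuantizeLinear -> DequantizeLinear on the FP32 weights instead of
        # storing pre-quantized int8 weights behind a lone DequantizeLinear.
        "AddQDQPairToWeight": True,
        # Keep the Conv/Gemm bias in FP32: no int32 bias initializer, no DQ node.
        "QuantizeBias": False,
    }

# Usage (needs onnxruntime, the model file and a calibration data reader,
# so it is left commented out here):
# from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
# quantize_static(
#     'mobilenetv2-opset10-infer.onnx',
#     'mobilenetv2-opset10-quantized.onnx',
#     data_reader,
#     quant_format=QuantFormat.QDQ,
#     per_channel=False,
#     weight_type=QuantType.QInt8,
#     activation_type=QuantType.QInt8,
#     extra_options=tensorrt_qdq_extra_options(),
# )
```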