microsoft / onnxruntime


QDQ + Add nodes are not fused into QLinearAdd when the graph is optimized #12487

Open fxmarty opened 2 years ago

fxmarty commented 2 years ago

Describe the bug

Following https://github.com/microsoft/onnxruntime/issues/11599, I am looking at what happens during inference when an InferenceSession is fed a QDQ model, using the CPUExecutionProvider.

Urgency None

System information

To Reproduce: The original and quantized ONNX models, as well as the optimized model produced after quantization, can be found here: https://huggingface.co/fxmarty/sst2-onnx/tree/main

import onnxruntime

model_path = "/path/to/model-quantized.onnx"

options = onnxruntime.SessionOptions()
# Dump the graph as optimized by ORT so that we can inspect what actually runs.
options.optimized_model_filepath = "quantized_model_optimized.onnx"

session = onnxruntime.InferenceSession(
    model_path,
    sess_options=options,
    providers=["CPUExecutionProvider"],
)

Running this will create quantized_model_optimized.onnx, which is the model actually used at run time.

In the model-quantized.onnx, we have the following nodes:

[screenshot of the relevant nodes in model-quantized.onnx]

In the quantized_model_optimized.onnx, they result in:

[screenshot of the corresponding nodes in quantized_model_optimized.onnx]

As we can see, the MatMul has been transformed into a QLinearMatMul so that the arithmetic runs on integers. However, the Add operation is left as is, which also means that DequantizeLinear operations remain in the graph.

The result is that, in benchmarks, running a statically quantized model with onnxruntime is slower than running a dynamically quantized one, which in my opinion is due to the leftover DequantizeLinear nodes:

[screenshot of the benchmark results]

Expected behavior: Avoid leaving unnecessary DequantizeLinear nodes, e.g. by fusing into QLinearAdd nodes ( https://github.com/microsoft/onnxruntime/blob/master/docs/ContribOperators.md#com.microsoft.QLinearAdd ).
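For reference, QLinearAdd is a com.microsoft contrib op rather than a standard ONNX op. A minimal sketch of constructing such a node with onnx.helper, with placeholder tensor names and the input order taken from the ContribOperators doc linked above:

from onnx import helper

# Tensor names are placeholders, for illustration only.
qlinear_add = helper.make_node(
    "QLinearAdd",
    inputs=[
        "A", "A_scale", "A_zero_point",
        "B", "B_scale", "B_zero_point",
        "C_scale", "C_zero_point",
    ],
    outputs=["C"],
    domain="com.microsoft",  # contrib op domain
)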

edgchen1 commented 2 years ago

[screenshot from Netron of the DequantizeLinear inputs feeding the Add]

Currently ORT requires that the data types of the DQ inputs be the same in order for them to be fused into a QLinearAdd. Looking at this particular case, one of the inputs is uint8 and the other is int8.

edgchen1 commented 2 years ago

Though, looking at model-quantized.onnx, the graph starts with int8 values. @yufenglee, perhaps the QDQS8ToU8Transformer is preventing some QDQ fusions in this case?

fxmarty commented 2 years ago

Thank you! How do you know whether the output of QuantizeLinear is uint8 or int8 when looking at the ONNX file? From the zero point being non-zero?

edgchen1 commented 2 years ago

Right. In this case, you can see the type of the zero point initializer in Netron.
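For anyone who prefers to check this programmatically rather than in Netron, a minimal sketch assuming the onnx Python package (the model path is a placeholder): the third input of a QuantizeLinear or DequantizeLinear node is its zero point, and the element type of that initializer tells you whether the quantized type is int8 or uint8.

import onnx
from onnx import TensorProto

model = onnx.load("model-quantized.onnx")  # placeholder path
initializers = {init.name: init for init in model.graph.initializer}

for node in model.graph.node:
    if node.op_type in ("QuantizeLinear", "DequantizeLinear") and len(node.input) > 2:
        zero_point = initializers.get(node.input[2])  # the third input is the zero point
        if zero_point is not None:
            # Prints INT8 or UINT8, i.e. the quantized element type of the node.
            print(node.name, TensorProto.DataType.Name(zero_point.data_type))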

yufenglee commented 2 years ago

We need special handling on x64 for cases where one input of Add or Mul is a weight.

fxmarty commented 2 years ago

@yufenglee I am open to contributing to this if you can give me some direction.

yufenglee commented 2 years ago

@fxmarty, that's great! One solution is to refine the logic that decides when to convert an s8 weight to u8. Currently, an s8 weight is only converted to u8 when kOrtSessionOptionsAvx2PrecisionMode is set to true (or '1') on x64; the case where the weight feeds Add, Mul and other ops is not handled. Actually, if the consumers of the output of the DQ node are not MatMul, Conv or Gemm, the s8 weight should also be converted to u8. So we can change the check to if (QDQ::MatchDQNode(node) && ShouldConvertWeightFromS8ToU8(graph, node)). The logic of ShouldConvertWeightFromS8ToU8 would be something like:

https://github.com/microsoft/onnxruntime/blob/819c36701f066e2e37a86fa7c4a169ec80b0e374/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc#L106
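To make the proposed rule concrete, here is an illustration in Python over an onnx GraphProto (this is not the actual C++ change in the transformer, and the helper name is made up): flag DequantizeLinear nodes whose data input is an s8 weight initializer and whose consumers are not MatMul, Conv or Gemm.

import onnx

# Ops for which the existing x64 s8-weight handling already applies.
HANDLED_OPS = {"MatMul", "Conv", "Gemm"}

def dq_weights_to_convert(model: onnx.ModelProto):
    # Illustration of the proposed rule, not ORT's implementation.
    graph = model.graph
    initializers = {init.name: init for init in graph.initializer}
    flagged = []
    for node in graph.node:
        if node.op_type != "DequantizeLinear":
            continue
        weight = initializers.get(node.input[0])
        # Candidate only if the DQ data input is a weight stored as an int8 initializer.
        if weight is None or weight.data_type != onnx.TensorProto.INT8:
            continue
        consumers = [n for n in graph.node if node.output[0] in n.input]
        # Proposed rule: convert the weight to u8 when its consumers are not MatMul/Conv/Gemm.
        if consumers and all(n.op_type not in HANDLED_OPS for n in consumers):
            flagged.append(node.name or node.output[0])
    return flagged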

bilelomrani1 commented 1 year ago

Is there any news on this? It seems to be a very nice and important fix.

fxmarty commented 1 year ago

@bilelomrani1 I'm happy to have a second look at it. @yufenglee, if you can provide some guidance: https://github.com/microsoft/onnxruntime/pull/12631#issuecomment-1225788746