microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

CPUExecutionProvider outputs wrong value for a quantized model #11532

Open shinh opened 2 years ago

shinh commented 2 years ago

Describe the bug

If you feed np.zeros((1, 120, 28, 28)) to this model, the output from the CPU execution provider does not match the one from CUDA. I believe the CUDA output is correct.

Urgency none

System information

To Reproduce

Download http://shinh.skr.jp/t/quant_wrong.onnx and run the following script:

import numpy as np
import onnx
from onnx import numpy_helper
import onnxruntime as ort

model_name = "quant_wrong.onnx"

x = np.zeros((1, 120, 28, 28)).astype(np.float32)

for provider in ["CPUExecutionProvider", "CUDAExecutionProvider"]:
    sess = ort.InferenceSession(model_name, providers=[provider])
    y, = sess.run(["output"], {"input": x})
    print(provider, y.sum())

# Compute the correct answer.
for init in onnx.load(model_name).graph.initializer:
    if init.name == "onnx::Conv_12_quantized":
        quantized_bias = numpy_helper.to_array(init)
    if init.name == "onnx::Conv_12_quantized_scale":
        quantized_scale = numpy_helper.to_array(init)
    if init.name == "onnx::Conv_12_quantized_zero_point":
        quantized_zero_point = numpy_helper.to_array(init)

# Dequantize the bias as (q - zero_point) * scale. The zero point is 0 in
# this model, so this is numerically the same as q * scale + zero_point.
bias = (quantized_bias - quantized_zero_point) * quantized_scale
expected = np.fmax(bias, 0)  # ReLU
# With an all-zero input, each output channel is just the (ReLU'd) bias
# broadcast over the 28x28 spatial map, hence the factor of 28 * 28.
print("expected", expected.sum() * 28 * 28)

I got this result:

CPUExecutionProvider 40.80013
CUDAExecutionProvider 38148.113
expected 38293.33673545718

Expected behavior

The three output values should be similar. I think the output from CPU is wrong.

Additional context

This ONNX model quantizes only the weight per-channel; the bias is quantized in a per-tensor manner (a single scale and zero point for the whole tensor). My guess is that CPUExecutionProvider makes a wrong assumption about how the model is quantized.
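To make the distinction concrete, here is a minimal numpy sketch of the two dequantization schemes (shapes and values are made up for illustration, not taken from the model):

import numpy as np

# Per-channel (as used for the weight): one scale per output channel.
w_q = np.random.randint(-127, 128, size=(120, 120, 3, 3), dtype=np.int8)
w_scale = np.full(120, 0.02, dtype=np.float32)            # shape (120,)
w = w_q.astype(np.float32) * w_scale[:, None, None, None]

# Per-tensor (as used for the bias): a single scale and zero point
# for the whole tensor; the bias itself is stored as int32.
b_q = np.random.randint(-1000, 1000, size=120).astype(np.int32)
b_scale = np.float32(0.05)
b_zp = np.int32(0)
b = (b_q - b_zp).astype(np.float32) * b_scale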

This ONNX model was created while I was debugging https://github.com/microsoft/onnxruntime/issues/11415. I used the last code snippet I posted on that issue to create it.

yufenglee commented 2 years ago

How did you get this model? The scale of the bias is assumed to be equal to scale_activation * scale_weight, but that is not the case in this model.
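For context on that assumption: a real quantized conv kernel accumulates x_q * w_q in int32, adds the int32 bias directly to the accumulator, and applies a single rescale of input_scale * weight_scale at the end, so the bias only dequantizes correctly if it was quantized with that combined scale. A toy 1-D numpy sketch of the idea (illustrative values, not onnxruntime's actual kernel):

import numpy as np

input_scale, weight_scale = np.float32(0.1), np.float32(0.02)

x_q = np.random.randint(-128, 128, size=8).astype(np.int32)
w_q = np.random.randint(-127, 128, size=8).astype(np.int32)
bias = np.float32(1.5)

# The bias must be quantized with the combined scale to be addable
# inside the int32 accumulator.
bias_q = np.int32(round(bias / (input_scale * weight_scale)))

acc = x_q @ w_q + bias_q                # all-integer accumulation
y = acc * (input_scale * weight_scale)  # single rescale at the end

expected = (x_q * input_scale) @ (w_q * weight_scale) + bias
assert np.isclose(y, expected, rtol=1e-3)

# With any other bias scale, the integer addition silently rescales the
# bias, which is consistent with the wrong CPU output above.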

BTW, for this model, the CUDA EP won't run it as real quantization. It will run the QuantizeLinear, DequantizeLinear, and Conv nodes literally.

shinh commented 2 years ago

I created the ONNX after patching onnx_quantizer.py as follows, to investigate https://github.com/microsoft/onnxruntime/issues/11415:

--- onnx_quantizer.py.orig  2022-05-16 13:28:19.759135197 +0900
+++ onnx_quantizer.py   2022-05-16 13:29:57.574089685 +0900
@@ -594,6 +594,12 @@

         # calcuate scale for bias
         bias_scale = input_scale * weight_scale * beta
+        _, _, bias_zero_point, bias_scale, quantized_data = quantize_data(
+            bias_data.flatten().tolist(),
+            onnx.TensorProto.INT8,
+            True)
+        assert bias_zero_point == 0
+        bias_scale = np.array(bias_scale)

         # quantize bias
         quantized_data = (np.asarray(bias_data) / bias_scale).round().astype(np.int32)

I'm guessing that sticking with bias_scale = input_scale * weight_scale isn't a great idea for per_channel mode, and I was trying the above patch. I don't think it is the right fix for #11415, but I thought it was worth filing #11532 as a separate issue, since it looked like the CPU execution provider does not follow ONNX's semantics.
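For reference, the quantize_data call in the patch derives the bias scale from the bias data's own range rather than from input_scale * weight_scale. A simplified sketch of symmetric int8 quantization in that spirit (an illustration, not onnxruntime's exact helper):

import numpy as np

def symmetric_int8_quantize(data):
    # Symmetric quantization centers the range on zero, so the zero
    # point is 0, matching the assert in the patch above.
    data = np.asarray(data, dtype=np.float32)
    scale = np.float32(max(np.abs(data).max(), 1e-12) / 127.0)
    quantized = np.clip(np.round(data / scale), -127, 127).astype(np.int8)
    return np.int8(0), scale, quantized  # zero_point, scale, quantized data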

So yes, bias_scale is not input_scale * weight_scale. I agree my ONNX breaks the assumption of the backend providers. However, I think it is still a valid ONNX model. For such a model, IMHO backend providers should either fall back to the slow, fake execution that the CUDA EP uses, or raise an exception. Silently outputting wrong values doesn't seem ideal.

dmal-msft commented 9 months ago

Are there any updates on this or any plans to fix this bug?