Open rogerw10 opened 1 year ago
This seems to happen when QLinearConv's activation data type (uint8) doesn't match the weight's data type (int8). I tried changing the weight's data type and its zero point's data type to uint8, and the onnxruntime inference result was correct.
Hi @xulongwu4, thanks for the insight! Besides uint8, I also tested changing all data types to int8, and that works fine as well. It seems the bug only happens when the data types of the input and the weights are different, but I feel that is a common case.
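For anyone who wants to try that kind of workaround on their own model, here is a rough sketch of re-typing an int8 weight initializer and its zero point as uint8 by shifting both by 128, which leaves the dequantized weights unchanged. The initializer names are illustrative and it assumes a per-tensor zero point:

```python
import numpy as np
import onnx
from onnx import numpy_helper


def weight_s8_to_u8(model: onnx.ModelProto, weight_name: str, zero_point_name: str) -> None:
    """Re-type an int8 weight initializer and its zero point as uint8 by shifting both by 128."""
    inits = {init.name: init for init in model.graph.initializer}
    w = numpy_helper.to_array(inits[weight_name]).astype(np.int16)
    zp = numpy_helper.to_array(inits[zero_point_name]).astype(np.int16)
    # (w - zp) is unchanged by the shift, so the dequantized weights stay the same
    inits[weight_name].CopyFrom(numpy_helper.from_array((w + 128).astype(np.uint8), name=weight_name))
    inits[zero_point_name].CopyFrom(numpy_helper.from_array((zp + 128).astype(np.uint8), name=zero_point_name))
```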
@yf711, I tested it with the latest version, 1.17.1, and the issue is still there. May I know whether it will be fixed? Thanks.
I still see this in ONNXRuntime 1.19.2
I've uploaded another test case that's likely the same bug to https://gist.github.com/mcollinswisc/b1e909bc2bb7a45659df61a71de8aa37. In that script I'm comparing against the ONNX reference implementation from https://github.com/onnx/onnx/blob/d7fbf2ba0cb6df3f8fe326cb6f519a7685ca904f/onnx/reference/ops/op_qlinear_conv.py
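(The comparison in the gist boils down to something like the following sketch; the model path, input name, and input shape here are just placeholders.)

```python
import numpy as np
import onnx
import onnxruntime as ort
from onnx.reference import ReferenceEvaluator

# Placeholder model path, input name, and shape; the real values are in the gist.
model = onnx.load("qlinearconv_repro.onnx")
x = np.random.randint(0, 256, size=(1, 1, 5, 5), dtype=np.uint8)

reference = ReferenceEvaluator(model)                # ONNX reference implementation of QLinearConv
session = ort.InferenceSession(model.SerializeToString())

expected = reference.run(None, {"x": x})[0]
actual = session.run(None, {"x": x})[0]
print(np.array_equal(expected, actual))              # prints False when the bug reproduces
```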
Also got this written as a C++ unit test in https://github.com/microsoft/onnxruntime/commit/298d062c802bb49ca1bbe809d8ea4126da8cdf9d.
It gives this result:
1: [ RUN ] QLinearConvTest.Conv2D_U8S8U8_DifferentInputAndWeightSignedness
1: /home/mcollins/repo/onnxruntime/onnxruntime/test/providers/checkers.cc:271: Failure
1: Expected equality of these values:
1: cur_expected[i]
1: Which is: '\b' (8)
1: cur_actual[i]
1: Which is: '\0'
1: i:4
1: Google Test trace:
1: /home/mcollins/repo/onnxruntime/onnxruntime/test/providers/checkers.cc:568: provider type: CPUExecutionProvider
1: /home/mcollins/repo/onnxruntime/onnxruntime/test/providers/base_tester.cc:830: registered execution providers: CPUExecutionProvider
1:
1: /home/mcollins/repo/onnxruntime/onnxruntime/test/providers/checkers.cc:271: Failure
1: Expected equality of these values:
1: cur_expected[i]
1: Which is: 'X' (88, 0x58)
1: cur_actual[i]
1: Which is: 'I' (73, 0x49)
1: i:5
...
(same values as seen in the Python example's output)
Another observation about the script from https://gist.github.com/mcollinswisc/b1e909bc2bb7a45659df61a71de8aa37: I see the right answer on an Intel CPU, but the wrong answer on an AMD CPU.
For an AMD Ryzen 9 5900X 12-Core Processor, it outputs:
[array([[[[ 0, 0, 0, 0, 0],
[73, 0, 0, 35, 45],
[ 0, 0, 0, 0, 32],
[20, 0, 46, 39, 35],
[ 0, 0, 29, 0, 0]],
[[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[57, 0, 0, 0, 0],
[ 0, 0, 0, 16, 7],
[27, 0, 0, 3, 0]]]], dtype=uint8)]
[array([[[[ 0, 0, 0, 0, 8],
[88, 0, 0, 59, 78],
[ 0, 0, 0, 0, 32],
[20, 0, 55, 39, 35],
[ 0, 0, 36, 0, 0]],
[[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[57, 0, 0, 0, 0],
[ 0, 0, 0, 16, 7],
[27, 0, 0, 3, 0]]]], dtype=uint8)]
On an 11th Gen Intel(R) Core(TM) i7-11800H @ 2.30GHz, the two outputs are the same:
[array([[[[ 0, 0, 0, 0, 8],
[88, 0, 0, 59, 78],
[ 0, 0, 0, 0, 32],
[20, 0, 55, 39, 35],
[ 0, 0, 36, 0, 0]],
[[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[57, 0, 0, 0, 0],
[ 0, 0, 0, 16, 7],
[27, 0, 0, 3, 0]]]], dtype=uint8)]
[array([[[[ 0, 0, 0, 0, 8],
[88, 0, 0, 59, 78],
[ 0, 0, 0, 0, 32],
[20, 0, 55, 39, 35],
[ 0, 0, 36, 0, 0]],
[[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[57, 0, 0, 0, 0],
[ 0, 0, 0, 16, 7],
[27, 0, 0, 3, 0]]]], dtype=uint8)]
It looks like, judging from this call: https://github.com/microsoft/onnxruntime/blob/709368ea1443dc1ff68dece31d692ad065fb94d4/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp#L215
on my AMD system (which gives the wrong answer) it goes into QgemmU8X8KernelAvx2.S, but on my Intel system (which gives the correct answer) it goes into QgemmU8X8KernelAvx512Core.S.
Reverting https://github.com/microsoft/onnxruntime/commit/d5f6343a4afbb1d3ae7acdc881f79bd93cc7c0b5 does not change the results.
Wrote up a unit test in the mlas tests on QGemm that also reproduces the error, in https://github.com/mcollinswisc/onnxruntime/commit/498579f23f216f31ac238f88c4ae2f0e2c98867c. Inputs in this test come from the input that QLinearConv passes to MlasGemm when running the example from previous comments.
On my AMD system it gives output:
[ RUN ] QGemmU8S8_Int32_NoPack_SingleThread_SpecialCase.CaseFromQLinearConv
/home/mcollins/repo/onnxruntime/onnxruntime/test/mlas/unittest/test_qgemm.h:217: Failure
Expected equality of these values:
C[f]
Which is: -3527
CReference[f]
Which is: -11219
@[0x2x0], Batch=1M=25, N=2, K=9, offa=121, offb=0
Stack trace:
0x635e7985f599: MlasQgemmTest<>::Test()
0x635e7985e498: MlasQgemmFromConvTest::Test()
0x635e7985e6ab: QgemmFromConvExecuteTest::TestBody()
0x635e79931a15: testing::internal::HandleSehExceptionsInMethodIfSupported<>()
0x635e79927917: testing::internal::HandleExceptionsInMethodIfSupported<>()
0x635e79906a5e: testing::Test::Run()
0x635e79907588: testing::TestInfo::Run()
... Google Test internal frames ...
[ FAILED ] QGemmU8S8_Int32_NoPack_SingleThread_SpecialCase.CaseFromQLinearConv, where GetParam() = CaseFromQLinearConv (2 ms)
And on my Intel system the test passes.
A yet smaller test case in https://github.com/mcollinswisc/onnxruntime/commit/5ae70950e65b72e12f9803f4ba64630670ed6d09 makes it clear what is happening.
The input ColumnSumBuffer here (https://github.com/microsoft/onnxruntime/blob/abad69b322512a9372812399c8eb8fe6c7d9a193/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp) is equal to -19602 = (127+35)*(-121).
On my AMD system (Ryzen 9 5900X, with AVX2 but not AVX-VNNI or AVX512), when it calls into https://github.com/microsoft/onnxruntime/blob/709368ea1443dc1ff68dece31d692ad065fb94d4/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S, the result is 13165, instead of the correct answer 20653. 13165 is equal to ((1 << 15) - 1) - 19602.
I guess when computing 127*250 + 35*243, the `vpmaddubsw` instruction is saturating, giving the max value for int16 as its result, from which 19602 is then subtracted. I think it is the `vpmaddubsw` instruction at this line where I am seeing it saturate:
https://github.com/microsoft/onnxruntime/blob/abad69b322512a9372812399c8eb8fe6c7d9a193/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S#L123
Though there are `vpmaddubsw` instructions for other block sizes that can, I think, cause the same error:
https://github.com/microsoft/onnxruntime/blob/abad69b322512a9372812399c8eb8fe6c7d9a193/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S#L127
https://github.com/microsoft/onnxruntime/blob/abad69b322512a9372812399c8eb8fe6c7d9a193/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S#L73
https://github.com/microsoft/onnxruntime/blob/abad69b322512a9372812399c8eb8fe6c7d9a193/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S#L77
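To spell out the arithmetic in the failing case, here is a plain-numpy sketch of the saturating pairwise multiply-add (just the math, not the actual kernel):

```python
import numpy as np

# One pair of u8 activation * s8 weight products, as vpmaddubsw computes them
activations = np.array([250, 243], dtype=np.int32)  # unsigned 8-bit values
weights = np.array([127, 35], dtype=np.int32)       # signed 8-bit values
products = activations * weights                    # [31750, 8505]; each product fits in int16

exact_sum = int(products.sum())                              # 40255
saturated_sum = min(exact_sum, int(np.iinfo(np.int16).max))  # clamped to 32767 by the pairwise add

column_sum = -19602  # (127 + 35) * (-121), the zero-point correction from ColumnSumBuffer
print(exact_sum + column_sum)      # 20653, the correct answer
print(saturated_sum + column_sum)  # 13165, what the AVX2 kernel produces
```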
I think I have also just come across this bug; here is my MRE:
onnx == 1.17.0
onnxruntime == 1.18.0
```python
import io

import numpy as np
import onnx
import onnx.numpy_helper as numpy_helper
import onnxruntime as ort

INPUT_SCALE = np.array([0.01], dtype=np.float32)
INPUT_ZERO = np.array([0], dtype=np.uint8)
INPUT_QUANTIZED = np.array([[210, 198, 234]], dtype=np.uint8)

WEIGHT_SCALE = np.array([0.01], dtype=np.float32)
WEIGHT_ZERO = np.array([0], dtype=np.int8)
# for the first column below values of -127 to -81 give the wrong result
WEIGHT_QUANTIZED = np.array([[-81, -80, 127]], dtype=np.int8)

BIAS_QUANTIZED = np.array([0], dtype=np.int32)


def build_onnx_graph() -> io.BytesIO:
    input_tensor = numpy_helper.from_array(INPUT_QUANTIZED, name="input")
    weight_tensor = numpy_helper.from_array(WEIGHT_QUANTIZED, name="weight")
    bias_tensor = numpy_helper.from_array(BIAS_QUANTIZED, name="bias")
    i_z = numpy_helper.from_array(INPUT_ZERO, name="input_zero")
    w_z = numpy_helper.from_array(WEIGHT_ZERO, name="weight_zero")
    i_s = numpy_helper.from_array(INPUT_SCALE, name="input_scale")
    w_s = numpy_helper.from_array(WEIGHT_SCALE, name="weight_scale")
    initers = [input_tensor, weight_tensor, bias_tensor, i_z, w_z, i_s, w_s]
    nodes = [
        onnx.helper.make_node(
            "QGemm",
            ["input", "input_scale", "input_zero", "weight", "weight_scale", "weight_zero", "bias"],
            ["output"],
            transB=1,
            domain="com.microsoft",
        ),
    ]
    output = onnx.helper.make_tensor_value_info(
        "output", onnx.TensorProto.FLOAT, [INPUT_QUANTIZED.shape[0], WEIGHT_QUANTIZED.shape[0]]
    )
    graph = onnx.helper.make_graph(
        nodes=nodes,
        name="mre_linear",
        inputs=[],
        outputs=[output],
        initializer=initers,
    )
    model = onnx.helper.make_model(
        graph,
        opset_imports=[
            onnx.helper.make_opsetid("com.microsoft", 1),
            onnx.helper.make_opsetid("", 19),
        ],
    )
    model_str = model.SerializeToString()
    return io.BytesIO(model_str)


def expected_convert_to_float_and_matmul() -> np.ndarray:
    weight = (WEIGHT_QUANTIZED - WEIGHT_ZERO) * WEIGHT_SCALE
    input = (INPUT_QUANTIZED - INPUT_ZERO) * INPUT_SCALE
    bias = BIAS_QUANTIZED * INPUT_SCALE * WEIGHT_SCALE
    return np.matmul(input, weight.T) + bias


def expected_quantized_to_matmul_to_float() -> np.ndarray:
    # convert to int32
    input = INPUT_QUANTIZED.astype(np.int32) - INPUT_ZERO.astype(np.int32)
    weight = WEIGHT_QUANTIZED.astype(np.int32) - WEIGHT_ZERO.astype(np.int32)
    bias = BIAS_QUANTIZED
    quantized_result = np.matmul(input, weight.T) + bias
    quantized_result = quantized_result.astype(np.float64)
    float_result = quantized_result * INPUT_SCALE * WEIGHT_SCALE
    return float_result


def onnx_output() -> np.ndarray:
    file = build_onnx_graph()
    opts = ort.SessionOptions()
    opts.inter_op_num_threads = 1
    opts.intra_op_num_threads = 1
    sess = ort.InferenceSession(file.read(), sess_options=opts)
    return sess.run(["output"], {})[0]


def test_broken() -> None:
    # what i am expecting to get out
    expected_to_f_to_out = expected_convert_to_float_and_matmul()
    expected_q_to_out = expected_quantized_to_matmul_to_float()
    got = onnx_output()
    print("fm_f", expected_to_f_to_out)
    print("qm_f", expected_q_to_out)
    print("onnx", got)


if __name__ == "__main__":
    np.set_printoptions(precision=5, linewidth=np.inf)
    test_broken()
```
It looks like this can be worked around by enabling `session.x64quantprecision` (aka `kOrtSessionOptionsAvx2PrecisionMode`):

```python
opts.add_session_config_entry("session.x64quantprecision", "1")
```
If I add that setting then the initial example reported by @rogerw10 as well as the repros from @mcollinswisc and @ben-da6 give the expected results.
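For reference, a minimal sketch of applying the workaround when creating a session (the model path here is just a placeholder):

```python
import onnxruntime as ort

opts = ort.SessionOptions()
# kOrtSessionOptionsAvx2PrecisionMode: prefer the precise (non-saturating) u8s8 path on AVX2
opts.add_session_config_entry("session.x64quantprecision", "1")

sess = ort.InferenceSession("model.onnx", sess_options=opts)  # "model.onnx" is a placeholder path
```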
This setting was added in #12088, and that PR has an explanation of why this problem happens.
It seems like the wrong choice that this defaults to off, though. Surely the more correct approach should be the default, and if people want faster performance and understand the risk, they could opt in to disabling it? That would save a lot of hassle for people who run into this and then need to figure out what is going wrong. @chenfucn, should this default be changed?
Describe the issue
I have a sample ONNX file with a QLinearConv block, attached below. When running it with a specific input using onnxruntime, the inference output is different from what is expected: in the example code below, the expected output is 103, but I got 92.
There seems to be no issue with smaller input data, but I cannot figure out why it would overflow. Thanks.
model.zip
To reproduce
Urgency
No response
Platform
Linux
OS Version
Ubuntu 18.04.6 LTS
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.13.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CPU
Execution Provider Library Version
PYTHON 3.8.10