Status: closed by sdasgup3 (3 months ago).
Problem

Consider the following quantized `stablehlo.maximum` operation:
```mlir
func.func @max_per_tensor_diff_quant_parameters(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>> {
  %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
  func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
```
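As a reading aid, here is a minimal Python sketch that decodes the per-tensor shorthand used in the types above, `!quant.uniform<storage:expressed,scale:zero_point>`, and dequantizes a stored integer as `(q - zero_point) * scale`. The helpers `parse_uniform_type` and `dequantize` are hypothetical names, and the regex handles only this simple per-tensor form (no per-axis parameters or explicit storage bounds).

```python
import re
from typing import NamedTuple


class UniformType(NamedTuple):
    storage: str      # storage (integer) type, e.g. "i8"
    expressed: str    # expressed (real) type, e.g. "f32"
    scale: float
    zero_point: int


# Matches only the per-tensor shorthand, e.g. "!quant.uniform<i8:f32,2.0:1>".
_PATTERN = re.compile(r"!quant\.uniform<(\w+):(\w+),\s*([-+0-9.eE]+):(-?\d+)>")


def parse_uniform_type(text: str) -> UniformType:
    m = _PATTERN.fullmatch(text.strip())
    if m is None:
        raise ValueError(f"unsupported quantized type: {text}")
    storage, expressed, scale, zp = m.groups()
    return UniformType(storage, expressed, float(scale), int(zp))


def dequantize(q: int, t: UniformType) -> float:
    # Real value represented by the stored integer q under type t.
    return (q - t.zero_point) * t.scale


for s in ["!quant.uniform<i8:f32,1.0:0>",
          "!quant.uniform<i8:f32,2.0:1>",
          "!quant.uniform<i8:f32,3.0:2>"]:
    t = parse_uniform_type(s)
    print(s, "-> scale", t.scale, "zero point", t.zero_point,
          "| stored 5 means", dequantize(5, t))
```

Under these three types, the same stored value 5 decodes to 5.0, 8.0, and 9.0 respectively, which is why stored integers from differently quantized tensors cannot be compared directly.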
Note that the quantization parameters (scales and zero points) differ between the two operands and the result.
Currently, `--stablehlo-legalize-quant-to-int` legalizes this as follows:
```mlir
func.func @max_per_tensor_diff_quant_parameters(%arg0: tensor<i8>, %arg1: tensor<i8>) -> tensor<i8> {
  %0 = stablehlo.maximum %arg0, %arg1 : tensor<i8>
  func.return %0 : tensor<i8>
}
```
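To see concretely why this integer-only lowering is wrong for the example above, the following Python sketch picks two stored `i8` values (chosen arbitrarily for illustration), dequantizes them with the operand parameters, and compares the real maximum with what the lowered integer `stablehlo.maximum` represents under the result parameters `[3.0, 2]`.

```python
def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Real value represented by the stored integer q."""
    return (q - zero_point) * scale


# (scale, zero_point) pairs taken from the example's types.
ARG0 = (1.0, 0)
ARG1 = (2.0, 1)
RESULT = (3.0, 2)

a, b = 10, 3  # arbitrary stored i8 values

# What the quantized op is supposed to compute, in real terms.
expected_real = max(dequantize(a, *ARG0), dequantize(b, *ARG1))

# The current lowering just takes the integer maximum of the stored values...
naive_c = max(a, b)
# ...which the result type then interprets with its own scale and zero point.
naive_real = dequantize(naive_c, *RESULT)

print(f"expected maximum (real): {expected_real}")
print(f"naive lowering stores {naive_c}, which decodes to {naive_real}")
```

Here the expected maximum is 10.0, but the stored result 10 decodes to 24.0 under the result type, far from any rounding error.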
This lowering is only correct when all operands and the result have the same quantization parameters; otherwise, it produces incorrect results (see below for why).

Proposed solution

The PR adds a check that allows the legalization only when all operands and the result have the same quantization parameters, and emits an error otherwise.
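The rule enforced by the check can be sketched in Python as follows. This is not the PR's actual C++ pattern; `QuantParams` and `check_same_quant_params` are hypothetical names used only to illustrate "all operands and the result must have identical scale and zero point, otherwise fail".

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class QuantParams:
    scale: float
    zero_point: int


def check_same_quant_params(operands: Sequence[QuantParams],
                            result: QuantParams) -> None:
    """Raise ValueError if any operand's parameters differ from the result's."""
    for i, p in enumerate(operands):
        if p != result:
            raise ValueError(
                f"operand #{i} has {p} but result has {result}; legalization "
                "requires identical quantization parameters")


# Same parameters everywhere: lowering to a plain integer maximum is allowed.
check_same_quant_params([QuantParams(1.0, 0), QuantParams(1.0, 0)],
                        QuantParams(1.0, 0))

# The example from this issue: rejected.
try:
    check_same_quant_params([QuantParams(1.0, 0), QuantParams(2.0, 1)],
                            QuantParams(3.0, 2))
except ValueError as err:
    print("rejected:", err)
```

The comparison is deliberately exact rather than tolerance-based: the integer lowering is only valid when the parameters are truly identical, so "close enough" scales would reintroduce a small but real error.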
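Appendix: Reason that it is a problem

Let `a` and `b` be the quantized input values with scale/zero point `[s1, z1]` and `[s2, z2]` respectively, and let `c` be the quantized output with scale/zero point `[s3, z3]`. From the semantics of the operation (conceptually: dequantize the inputs, take the maximum, requantize the result), we have, up to rounding in the requantization,

`(c - z3) * s3 = max((a - z1) * s1, (b - z2) * s2)`

Only when `z1 = z2 = z3 = z` and `s1 = s2 = s3 = s` does this simplify to `c = max(a, b)`: since `s > 0`, the right-hand side becomes `s * max(a - z, b - z) = s * (max(a, b) - z)`, and dividing both sides by `s` and adding `z` gives `c = max(a, b)`. That simplification is exactly what the current legalization assumes.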
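The derivation can be checked numerically. The Python sketch below computes a reference result by dequantizing both inputs, taking the real maximum, and requantizing to `i8`. Rounding half to even (Python's built-in `round`) followed by clamping to [-128, 127] is an assumption of this sketch, not a quotation of the StableHLO spec. The sketch then confirms that the reference equals `max(a, b)` for every pair of `i8` values when all parameters are identical, and counts how often it differs for the parameters in this issue.

```python
def dequantize(q: int, scale: float, zp: int) -> float:
    return (q - zp) * scale


def quantize_i8(x: float, scale: float, zp: int) -> int:
    # Round to nearest (ties to even), shift by the zero point, clamp to i8.
    q = round(x / scale) + zp
    return max(-128, min(127, q))


def reference_max(a: int, b: int, p1, p2, p3) -> int:
    """Dequantize a with p1 and b with p2, take the real max, requantize with p3."""
    real = max(dequantize(a, *p1), dequantize(b, *p2))
    return quantize_i8(real, *p3)


I8 = range(-128, 128)

# Identical parameters: the reference always equals the plain integer maximum.
same = (0.5, -3)
assert all(reference_max(a, b, same, same, same) == max(a, b)
           for a in I8 for b in I8)

# Parameters from this issue: the plain integer maximum is wrong for many inputs.
p1, p2, p3 = (1.0, 0), (2.0, 1), (3.0, 2)
print("reference:", reference_max(10, 3, p1, p2, p3), "naive:", max(10, 3))
mismatches = sum(reference_max(a, b, p1, p2, p3) != max(a, b)
                 for a in I8 for b in I8)
print(f"{mismatches} of {len(I8) ** 2} input pairs differ from max(a, b)")
```

For `a = 10`, `b = 3`, the reference stores 5 (decoding to 9.0, the representable value nearest to the true maximum 10.0), while the integer-only lowering stores 10.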