
Fix legalization of quantized Min/Max op to Int ops #2395

Closed sdasgup3 closed 3 months ago

sdasgup3 commented 3 months ago

Problem

Consider the following quantized stablehlo.maximum operation:

func.func @max_per_tensor_diff_quant_parameters(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) ->  tensor<!quant.uniform<i8:f32,3.0:2>> {
  %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
  func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}

Note that the quantization parameters (scales and zero points) of the two operands and the result are all different.

Currently --stablehlo-legalize-quant-to-int legalizes this as follows:

func.func @max_per_tensor_diff_quant_parameters(%arg0: tensor<i8>, %arg1: tensor<i8>) -> tensor<i8> {
  %0 = stablehlo.maximum %arg0, %arg1 : tensor<i8>
  func.return %0 : tensor<i8>
}

This is correct only when all operands and the result share the same quantization parameters; otherwise the lowering produces wrong results (see the appendix below for why).

Proposed solution

This PR adds a check so that the legalization is applied only when all operands and results have identical quantization parameters. Otherwise, the pass emits an error instead of silently producing an incorrect lowering.
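
For illustration only, here is a rough C++ sketch of the kind of type check described above, written against MLIR's quant dialect. The helper name, structure, and include paths are assumptions and do not reflect the PR's actual code; it succeeds only when every operand and result carries the same uniform quantized element type (same storage type, scale, and zero point).

// Sketch (not the PR's actual code): reject legalization when operands and
// results do not all share one uniform quantized element type.
// Include paths may differ across MLIR versions.
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

static mlir::LogicalResult verifySameQuantParams(mlir::Operation *op) {
  mlir::Type common;  // first uniform quantized element type seen
  auto check = [&](mlir::Type type) -> bool {
    auto shaped = mlir::dyn_cast<mlir::ShapedType>(type);
    if (!shaped)
      return false;
    auto qty = mlir::dyn_cast<mlir::quant::UniformQuantizedType>(
        shaped.getElementType());
    if (!qty)
      return false;
    if (!common) {
      common = qty;
      return true;
    }
    // Uniqued types compare equal only if storage type, expressed type,
    // scale, and zero point all match.
    return qty == common;
  };
  for (mlir::Type t : op->getOperandTypes())
    if (!check(t))
      return mlir::failure();
  for (mlir::Type t : op->getResultTypes())
    if (!check(t))
      return mlir::failure();
  return mlir::success();
}

In the pass, a failed check of this kind would make the conversion pattern report an error rather than emit the plain integer min/max op.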

Appendix: Why this is a problem

Let a and b be the quantized input values with scale/zero-point pairs [s1, z1] and [s2, z2] respectively, and let c be the quantized output with scale/zero point [s3, z3].

From the semantics of the operation we have

(c - z3)*s3 = max[ (a - z1)*s1, (b - z2)*s2 ]

Only if z1 = z2 = z3 and s1 = s2 = s3 does this simplify to c = max(a, b), which is what the current legalization computes.
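
As a concrete (illustrative) instance, take s1 = 1, z1 = 0, s2 = 2, z2 = 1, s3 = 3, z3 = 2, with stored values a = 4 and b = 3:

The operands represent (4 - 0)*1 = 4.0 and (3 - 1)*2 = 4.0, so the true result is max(4.0, 4.0) = 4.0.
The current legalization emits c = max(4, 3) = 4, which the output type interprets as (4 - 2)*3 = 6.0.
The correct stored output value would instead be c = 4.0/3 + 2 ≈ 3.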