openxla / stablehlo

Backward compatible ML compute opset inspired by HLO/MHLO
Apache License 2.0

Enhance `quant-to-int` with `qdq` Fallback #2456

Closed sdasgup3 closed 1 month ago

sdasgup3 commented 1 month ago

The stablehlo-legalize-quant-to-int pass has been used for decomposing quantized StableHLO programs for purposes like interpretation and XLA compilation.

However, the pass has only partial support: for a selected set of ops (add, dot_general, dot, convolution, max, min), it decomposes their quantized versions to integer math.

For example, the following is supported:

$ cat test.mlir
func.func @max_per_tensor_same_quant_parameters(
    %arg0: tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
  ) -> tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>> {
  %0 = "stablehlo.maximum"(%arg0, %arg0) : (
    tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>,
    tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
  ) -> tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
  return %0 : tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
}

$ stablehlo-opt test.mlir --stablehlo-legalize-quant-to-math
 func.func @max_per_tensor_same_quant_parameters(%arg0: tensor<1x2xi8>) -> tensor<1x2xi8> {
    %0 = stablehlo.maximum %arg0, %arg0 : tensor<1x2xi8>
    return %0 : tensor<1x2xi8>
  }

However, the following valid program trips an unsupported feature of the pass, erroring out with: error: stablehlo.maximum with different quantization parameters for operands and results is not supported.

func.func @max_per_tensor_diff_quant_parameters(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) ->  tensor<!quant.uniform<i8:f32,3.0:2>> {
  %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
  func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
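
For reference, running the same command on this program (saved here under the hypothetical name test_diff.mlir) reproduces the failure:

$ stablehlo-opt test_diff.mlir --stablehlo-legalize-quant-to-math
error: stablehlo.maximum with different quantization parameters for operands and results is not supported.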

Moreover, we have a pass stablehlo-legalize-quantized-op-to-qdq which could ideally be used as a fallback, converting the above to:

func.func @max_per_tensor_diff_quant_parameters(%arg0: tensor<!quant.uniform<i8:f32, 1.000000e+00>>, %arg1: tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>> {
    %0 = stablehlo.uniform_dequantize %arg0 : (tensor<!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<f32>
    %1 = stablehlo.uniform_dequantize %arg1 : (tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<f32>
    %2 = stablehlo.maximum %0, %1 : tensor<f32>
    %3 = stablehlo.uniform_quantize %2 : (tensor<f32>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
    return %3 : tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
  }
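
For comparison, this qdq form can be produced by invoking the fallback pass directly (assuming, as is conventional for *-opt tools, that the command-line flag matches the pass name):

$ stablehlo-opt test_diff.mlir --stablehlo-legalize-quantized-op-to-qdq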

However, the fact that stablehlo-legalize-quant-to-int errors out in such scenarios prohibits the use of stablehlo-legalize-quantized-op-to-qdq as a fallback.

The goal of this issue is to enhance the stablehlo-legalize-quant-to-int pass to gracefully handle scenarios where it cannot directly decompose quantized operations. In such cases, it should allow the stablehlo-legalize-quantized-op-to-qdq pass to be used as a fallback to ensure the overall transformation pipeline continues without errors.
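
Concretely, if the pass declined unsupported ops instead of hard-failing, a pipeline like the following sketch would decompose every quantized op one way or the other (pass flags as used above; the ordering is the point, not the exact spelling):

$ stablehlo-opt test_diff.mlir \
    --stablehlo-legalize-quant-to-math \
    --stablehlo-legalize-quantized-op-to-qdq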

sdasgup3 commented 1 month ago

The following PRs (in order) provide the solution; a post-merge usage sketch follows the list.

  1. Merge qdq and quant-to-int passes #2458
  2. Support qdq decomposition of TanOp #2459
  3. Remove type-inference dependency while creating qdq patterns #2460
  4. Support qdq decomposition of DotGeneralOp and ConvolutionOp #2461
  5. Remove the quant-to-math pass limitation of Dot/Conv op result type #2462
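
With the passes merged in #2458, the expectation is that a single invocation covers both paths, producing integer math where supported and a qdq decomposition otherwise (a sketch, assuming the merged pass keeps the quant-to-math flag shown above):

$ stablehlo-opt test_diff.mlir --stablehlo-legalize-quant-to-math
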
sdasgup3 commented 1 month ago

Closing this issue with all the PRs merged.