openxla / stablehlo


[Reference interpreter] Remove need for canonicalizing away shape patterns #2390

Open sdasgup3 opened 3 weeks ago

sdasgup3 commented 3 weeks ago

To evaluate a quantized program, the StableHLO reference interpreter depends on the StablehloLegalizeQuantToInt pass, which introduces CHLO broadcast operations for scale multiplication/division and zero-point addition. Legalizing the CHLO operations to StableHLO operations introduces shape operations, which then need to be canonicalized away using several canonicalization passes.

We believe that for statically shaped programs we can avoid the need for CHLO broadcast operations altogether, and that this would simplify the decomposition pipeline for quantized operations.

For example, the following program

func.func @quantized_add() -> tensor<2xf32> {
  %operand1 = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32>
  %operand2 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32>
  %q_operand1 = "stablehlo.uniform_quantize"(%operand1) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.1:-30>>
  %q_operand2 = "stablehlo.uniform_quantize"(%operand2) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.5:-20>>
  %result = "stablehlo.add"(%q_operand1, %q_operand2) : (tensor<2x!quant.uniform<i8:f32, 0.1:-30>>, tensor<2x!quant.uniform<i8:f32, 0.5:-20>>) -> tensor<2x!quant.uniform<i8:f32, 0.5:-20>>
  %result_f = "stablehlo.uniform_dequantize"(%result) : (tensor<2x!quant.uniform<i8:f32, 0.5:-20>>) -> tensor<2xf32>
  func.return %result_f: tensor<2xf32>
}

needs to go through the following passes to be converted to a pure StableHLO program

--stablehlo-legalize-quant-to-int --chlo-legalize-to-stablehlo --canonicalize --shape-legalize-to-stablehlo --stablehlo-canonicalize-dynamism
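
For intuition, here is a minimal Python sketch of the arithmetic behind the @quantized_add example above. It is not what StablehloLegalizeQuantToInt emits: it works in floating point for readability, and the helper names (quantize, dequantize, quantized_add) are hypothetical. The point it illustrates is that scale and zero point are per-tensor scalars combined elementwise with a tensor, which is where the broadcasts come from; when the shape (here 2) is static, the broadcast result shape is known up front.

```python
# Illustrative only: float-based reference arithmetic for per-tensor
# uniform quantization. Helper names are hypothetical, not StableHLO APIs.
I8_MIN, I8_MAX = -128, 127


def quantize(values, scale, zero_point):
    # Scalar scale division and zero-point addition, applied elementwise.
    # Python's round() rounds halves to even.
    return [
        max(I8_MIN, min(I8_MAX, round(v / scale) + zero_point))
        for v in values
    ]


def dequantize(quantized, scale, zero_point):
    # Zero-point subtraction and scalar scale multiplication, elementwise.
    return [(q - zero_point) * scale for q in quantized]


def quantized_add(lhs, lhs_params, rhs, rhs_params, out_params):
    # Bring both operands to real values, add, then requantize to the
    # result type's (scale, zero_point).
    lhs_f = dequantize(lhs, *lhs_params)
    rhs_f = dequantize(rhs, *rhs_params)
    return quantize([a + b for a, b in zip(lhs_f, rhs_f)], *out_params)


# Same constants and quantization parameters as @quantized_add.
q_operand1 = quantize([1.0, 2.0], 0.1, -30)
q_operand2 = quantize([3.0, 4.0], 0.5, -20)
result = quantized_add(
    q_operand1, (0.1, -30), q_operand2, (0.5, -20), (0.5, -20)
)
result_f = dequantize(result, 0.5, -20)
print(q_operand1, q_operand2, result, result_f)
```

Every scalar-with-tensor step in this sketch operates on a tensor whose shape is fixed at compile time, so no runtime shape computation is needed to decide the broadcast result shape.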