openxla / stablehlo

Backward compatible ML compute opset inspired by HLO/MHLO
Apache License 2.0

Legalize quantized stablehlo operation using uniform_quantize/uniform_dequantize #2394

Closed sdasgup3 closed 3 months ago

sdasgup3 commented 3 months ago

This PR provides a pass that decomposes quantized StableHLO programs using uniform quantize/dequantize operations. For example, the following program

```mlir
func.func @add(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>> {
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
  func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
```

becomes:

```mlir
func.func @add(%arg0: tensor<!quant.uniform<i8:f32, 1.000000e+00>>, %arg1: tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>> {
  %0 = stablehlo.uniform_dequantize %arg0 : (tensor<!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<f32>
  %1 = stablehlo.uniform_dequantize %arg1 : (tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<f32>
  %2 = stablehlo.add %0, %1 : tensor<f32>
  %3 = stablehlo.uniform_quantize %2 : (tensor<f32>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
  return %3 : tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
}
```

Per docs/spec.md, the following is the exhaustive list of ops that can be interpreted using the dq-op-q strategy (dequantize the operands, apply the op in floating point, quantize the result). This PR handles all of these ops except DotGeneralOp, ConvolutionOp, DynamicConvOp, and AddOp, which are already lowered to integer math by the --stablehlo-legalize-quant-to-int pass.

  1. AbsOp
  2. AddOp
  3. Atan2Op
  4. BatchNormGradOp
  5. BatchNormInferenceOp
  6. BatchNormTrainingOp
  7. CbrtOp
  8. CeilOp
  9. CholeskyOp
  10. ClampOp
  11. CompareOp
  12. ConvolutionOp
  13. CosineOp
  14. DivOp
  15. DotGeneralOp
  16. DynamicConvOp
  17. Expm1Op
  18. ExpOp
  19. FloorOp
  20. Log1pOp
  21. LogisticOp
  22. LogOp
  23. MaxOp
  24. MinOp
  25. MulOp
  26. NegOp
  27. PowOp
  28. ReducePrecisionOp
  29. RemOp
  30. RoundOp
  31. RoundNearestEvenOp
  32. RsqrtOp
  33. SelectOp
  34. SignOp
  35. SineOp
  36. SqrtOp
  37. SubtractOp
  38. TanhOp
  39. TriangularSolveOp
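
To make the dq-op-q strategy concrete, here is a minimal numerical sketch in plain Python of what the decomposed `@add` example computes for scalar i8 values. The helper names (`dequantize`, `quantize`, `dq_op_q`) are hypothetical and are not part of this PR or of StableHLO. The rounding mode (round half to even) and the clamp to the i8 storage range are assumptions about how `uniform_quantize` behaves and should be checked against docs/spec.md.

```python
import operator

# Hypothetical helpers for illustration only; not StableHLO APIs.

def dequantize(q, scale, zero_point):
    """Map a stored integer to the real value it represents."""
    return (q - zero_point) * scale


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Map a real value to a stored i8 value.

    Assumption: round half to even (Python's round() does this),
    then add the zero point and clamp to the storage range.
    """
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dq_op_q(op, operands, operand_params, result_params):
    """Dequantize each operand, apply `op` in floating point, quantize the result."""
    reals = [dequantize(q, s, zp) for q, (s, zp) in zip(operands, operand_params)]
    scale, zero_point = result_params
    return quantize(op(*reals), scale, zero_point)


# (scale, zero_point) pairs taken from the @add example above.
LHS = (1.0, 0)
RHS = (2.0, 1)
OUT = (3.0, 2)

result = dq_op_q(operator.add, [6, 4], [LHS, RHS], OUT)
print(result)
```

With stored inputs 6 and 4, the operands dequantize to 6.0 and 6.0, the floating-point sum is 12.0, and quantizing with scale 3.0 and zero point 2 gives a stored value of 6. The same wrapper shape applies to any op in the list above by swapping the `op` argument.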