Open Nullkooland opened 6 months ago
Thanks for raising this issue @Nullkooland. @lsy323 can you please have a look?
Hi @Nullkooland, thank you for reporting the issue!
Upstream introduced a BC-breaking change in which the `fp -> quant` pair is folded by default. As you mentioned, `QParam` + `DeQuant` is not supported now, so you'll need to set `fold_quantize=False`. The repro script passed on my end after disabling the quant weight folding.
Also, please don't forget to set `STABLEHLO_BYTECODE_FROM_PRETTYPRINT=1` (as mentioned in https://github.com/pytorch/xla/pull/5763) until the related StableHLO issue is resolved.
@lsy323 Thanks for your reply!
Will the `QParam` + `DeQuant` QDQ case be supported in the future? So that the exported StableHLO would look like:
```mlir
module @ExampleQDQModel {
  func.func @main(
    %weight_0_q: tensor<32x3x3x3x!quant.uniform<i8:f32, 1.000000e+00>>, // Folded QParams.
    %input: tensor<1x3x64x64xf32>
  ) {
    %input_q = stablehlo.uniform_quantize %input : (tensor<1x3x64x64xf32>) -> tensor<1x3x64x64x!quant.uniform<i8:f32, 1.000000e+00>>
    %input_dq = stablehlo.uniform_dequantize %input_q : (tensor<1x3x64x64x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3x64x64xf32>
    %weight_0_dq = stablehlo.uniform_dequantize %weight_0_q : (tensor<32x3x3x3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<32x3x3x3xf32>
    %conv_0 = stablehlo.convolution(%input_dq, %weight_0_dq) {...} : (tensor<1x3x64x64xf32>, tensor<32x3x3x3xf32>) -> tensor<1x32x64x64xf32>
    %conv_0_q = stablehlo.uniform_quantize %conv_0 : (tensor<1x32x64x64xf32>) -> tensor<1x32x64x64x!quant.uniform<i8:f32, 1.000000e+00>>
    %conv_0_dq = stablehlo.uniform_dequantize %conv_0_q : (tensor<1x32x64x64x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x32x64x64xf32>
    ...
  }
}
```
so the quantized params could be exported directly to reduce the size of the exported artifacts.
Hi @Nullkooland, just FYI, `STABLEHLO_BYTECODE_FROM_PRETTYPRINT=1` is not needed anymore.
For this question I'm not sure; I'll leave it to a StableHLO team member to provide some input. cc @sdasgup3
> Will the `QParam` + `DeQuant` QDQ case be supported in the future? So that the exported StableHLO would look like:
Yes, we have plans to achieve that outcome. There may be different paths to the goal, such as doing the pattern matching at the ATen level vs. the StableHLO level, which we are still exploring.
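To make the idea concrete, here is an illustrative-only toy (not PyTorch/XLA's or StableHLO's actual implementation, and every name in it is hypothetical) of what QDQ pattern matching could look like over a linearized op sequence: find each dequantize whose producer is either a quantize op (activation QDQ) or an already-quantized parameter (the `QParam` + `DeQuant` case from this issue), since those pairs are what a converter would rewrite into a single quantized op:

```python
# Toy QDQ pattern matcher over a flat op list; each op records its name,
# kind, and the names of its producers.
from typing import NamedTuple


class Op(NamedTuple):
    name: str
    kind: str      # "qparam" | "quantize" | "dequantize" | "conv" | ...
    inputs: tuple  # names of producer ops


def find_qdq_pairs(ops):
    """Return (producer, dequantize) name pairs forming a QDQ pattern."""
    producers = {op.name: op for op in ops}
    pairs = []
    for op in ops:
        if op.kind != "dequantize":
            continue
        src = producers.get(op.inputs[0])
        # Match both the activation case (quantize + dequantize) and the
        # weight case (qparam + dequantize) discussed in this thread.
        if src is not None and src.kind in ("quantize", "qparam"):
            pairs.append((src.name, op.name))
    return pairs


# Mirrors the ExampleQDQModel graph sketched earlier in the thread.
graph = [
    Op("weight_0_q", "qparam", ()),
    Op("input", "placeholder", ()),
    Op("input_q", "quantize", ("input",)),
    Op("input_dq", "dequantize", ("input_q",)),
    Op("weight_0_dq", "dequantize", ("weight_0_q",)),
    Op("conv_0", "conv", ("input_dq", "weight_0_dq")),
]

print(find_qdq_pairs(graph))
# -> [('input_q', 'input_dq'), ('weight_0_q', 'weight_0_dq')]
```

A real converter would of course work on the FX/ATen graph (or StableHLO IR) rather than a flat list, and would then replace each matched pair with the corresponding quantized op.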
🐛 Bug
When exporting a pt2e quantized model to StableHLO, I got this error:
As far as I can tell, the QDQ converter patch introduced in #5763 only handles the `Quant` + `DeQuant` QDQ pair, but not the `QParam` + `DeQuant` QDQ pair, so the HLO -> MHLO conversion fails.
As seen in the following pt2e quantized model visualization, the Conv2d weight is a quantized param with `i8` dtype, followed by a `DeQuant` op:
To Reproduce
Here is the example code to reproduce the bug:
The XLA log containing the lowered HLO:
Environment