pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Failed to lower QParam + DeQuant QDQ pair to StableHLO #6567

Open Nullkooland opened 6 months ago

Nullkooland commented 6 months ago

🐛 Bug

When exporting a pt2e quantized model to StableHLO, I got this error:

error: 'mhlo.uniform_dequantize' op operand #0 must be tensor of 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<32x3x3x3xi8>'

As far as I can tell, the QDQ converter patch introduced in #5763 handles only the Quant + DeQuant pair, not the QParam + DeQuant pair, so the HLO -> MHLO conversion fails.

As the following visualization of the pt2e quantized model shows, the Conv2d weight is a quantized parameter with i8 dtype, followed by a DeQuant op:

[Figure: model_qdq graph visualization]
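
For context on the error message: a uniform-quantized tensor type carries its scale and zero point in the type itself, while a plain i8 tensor carries neither, so a dequantize op applied to it has nothing to decode with. Below is a minimal pure-Python sketch of the arithmetic, assuming the usual affine quantization scheme. The function names and the example scale of 0.05 are illustrative and are not part of torch_xla or StableHLO.

```python
def uniform_quantize(x, scale, zero_point, storage_min=-128, storage_max=127):
    """Map a float to its integer storage value (assumed affine scheme)."""
    scaled = x / scale + zero_point
    # Clamp to the storage range, then round (Python rounds halves to even).
    clamped = min(max(scaled, storage_min), storage_max)
    return int(round(clamped))


def uniform_dequantize(q, scale, zero_point):
    """Recover an approximate float from an integer storage value."""
    return (q - zero_point) * scale


# Round trip with illustrative parameters: 0.37 is stored as 7, decoded as ~0.35.
q = uniform_quantize(0.37, scale=0.05, zero_point=0)
x_hat = uniform_dequantize(q, scale=0.05, zero_point=0)

# Out-of-range values saturate at the storage bounds.
q_saturated = uniform_quantize(1000.0, scale=0.05, zero_point=0)
```

Without the scale and zero point, the integer 7 above cannot be mapped back to 0.35, which is the information the `!quant.uniform` element type is expected to supply.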

To Reproduce

Here is the example code to reproduce the bug:

import os

import torch
import torch.export
from torch import nn

from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
    prepare_pt2e,
    convert_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.fx.passes.graph_drawer import FxGraphDrawer

from torch_xla import stablehlo
import torch_xla.core.xla_model as xm

class TestModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv_0 = nn.Conv2d(
            in_channels=3,
            out_channels=32,
            kernel_size=(3, 3),
            padding=(1, 1),
            bias=False
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = self.conv_0(input)
        out = torch.sigmoid(x)
        return out

if __name__ == "__main__":
    device = xm.xla_device()

    model = TestModel().eval()
    input_tensor = torch.randn(size=(1, 3, 64, 64), dtype=torch.float32)
    sample_inputs = (input_tensor, )

    model = capture_pre_autograd_graph(model, sample_inputs)

    # Insert PTQ observers.
    quant_config = get_symmetric_quantization_config()
    quantizer = XNNPACKQuantizer().set_global(quant_config)

    model_prepared = prepare_pt2e(model, quantizer)

    # Do calibration.

    # Quantize the model.
    model_quant = convert_pt2e(model_prepared, fold_quantize=True)
    model_quant_exported = torch.export.export(model_quant, sample_inputs)

    # Visualize the quantized model.
    model_quant_exported.graph.print_tabular()
    drawer = FxGraphDrawer(model_quant_exported, "model_qdq")
    with open(f"{drawer._name}.svg", mode="wb") as f:
        f.write(drawer.get_dot_graph().create_svg())

    # Convert to stablehlo.
    print(f"[================ PID: {os.getpid()} ================]")
    model_stablehlo = stablehlo.exported_program_to_stablehlo(
        model_quant_exported
    )
    model_stablehlo_str = model_stablehlo.get_stablehlo_text()
    print(model_stablehlo_str)

The XLA log containing the lowered HLO:

Execution Analysis: ================================================================================
2024-02-19 16:44:53.339801: I torch_xla/csrc/runtime/pjrt_computation_client.cc:550] Executing PjRt computation on CPU:0
2024-02-19 16:44:53.339833: I external/xla/xla/pjrt/cpu/cpu_client.cc:1641] ExecuteShard executes computation SyncTensorsGraph.4 on assigned replica/partition on device TFRT_CPU_0
2024-02-19 16:44:53.340033: I torch_xla/csrc/runtime/pjrt_computation_client.cc:605] Returning 2 results
loc("custom-call.2"): error: 'mhlo.uniform_dequantize' op operand #0 must be tensor of 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<32x3x3x3xi8>'
2024-02-19 16:45:01.140535: I torch_xla/csrc/runtime/tf_logging.cc:12] Check failed: status.ok() 
*** Begin stack trace ***
        tsl::CurrentStackTrace()
        torch_xla::ConvertHloToStableHlo(xla::HloModuleProto const*, mlir::ModuleOp*)
        torch_xla::hloToStablehlo(xla::HloModuleProto const*, bool)
        torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
        torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)
...
        __libc_start_main
        _start
*** End stack trace ***
HLO -> MHLO conversion failed.
MHLO Module from HLO -> MHLO conversion is not legal.Please open a github issue to PyTorch/XLA.
Original HLO dump:
HloModule IrToHlo.11, entry_computation_layout={(s8[32,3,3,3]{3,2,1,0}, f32[1,3,64,64]{3,2,1,0})->(f32[1,32,64,64]{3,2,1,0})}

ENTRY %IrToHlo.11 (p0.1: s8[32,3,3,3], p1.3: f32[1,3,64,64]) -> (f32[1,32,64,64]) {
  %p1.3 = f32[1,3,64,64]{3,2,1,0} parameter(1)
  %custom-call.4 = s8[1,3,64,64]{3,2,1,0} custom-call(f32[1,3,64,64]{3,2,1,0} %p1.3), custom_call_target="stablehlo.uniform_quantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
  %custom-call.5 = f32[1,3,64,64]{3,2,1,0} custom-call(s8[1,3,64,64]{3,2,1,0} %custom-call.4), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
  %p0.1 = s8[32,3,3,3]{3,2,1,0} parameter(0)
  %custom-call.2 = f32[32,3,3,3]{3,2,1,0} custom-call(s8[32,3,3,3]{3,2,1,0} %p0.1), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-127,storage_max=127}
  %convolution.6 = f32[1,32,64,64]{3,2,1,0} convolution(f32[1,3,64,64]{3,2,1,0} %custom-call.5, f32[32,3,3,3]{3,2,1,0} %custom-call.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01
  %custom-call.7 = s8[1,32,64,64]{3,2,1,0} custom-call(f32[1,32,64,64]{3,2,1,0} %convolution.6), custom_call_target="stablehlo.uniform_quantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
  %custom-call.8 = f32[1,32,64,64]{3,2,1,0} custom-call(s8[1,32,64,64]{3,2,1,0} %custom-call.7), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
  %logistic.9 = f32[1,32,64,64]{3,2,1,0} logistic(f32[1,32,64,64]{3,2,1,0} %custom-call.8)
  ROOT %tuple.10 = (f32[1,32,64,64]{3,2,1,0}) tuple(f32[1,32,64,64]{3,2,1,0} %logistic.9)
}

Environment

miladm commented 6 months ago

Thanks for raising this issue @Nullkooland. @lsy323 can you please have a look?

lsy323 commented 6 months ago

Hi @Nullkooland, thank you for reporting the issue!

Upstream introduced a BC-breaking change in which the fp->quant pair is folded by default. As you mentioned, QParam + DeQuant is not supported yet, so you'll need to set fold_quantize=False. The repro script passed on my end after disabling the quantized weight folding.

Also, please don't forget to set STABLEHLO_BYTECODE_FROM_PRETTYPRINT=1 (as mentioned in https://github.com/pytorch/xla/pull/5763) until the related StableHLO issue is resolved.
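
The workaround described above could be sketched as follows. This is a hedged sketch, not a verified fix: `quantize_and_export_unfolded` is a hypothetical helper name, the calibration pass is an addition the repro script does not have, and every library call is taken from the repro script with only `fold_quantize` changed.

```python
import os

# Set before the export runs; per the comment above, this is only needed
# until the related StableHLO issue is resolved.
os.environ["STABLEHLO_BYTECODE_FROM_PRETTYPRINT"] = "1"


def quantize_and_export_unfolded(model, sample_inputs):
    """Hypothetical helper: the repro flow with fold_quantize=False."""
    # Imports are local so that defining the helper does not require torch_xla.
    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    from torch_xla import stablehlo

    captured = capture_pre_autograd_graph(model, sample_inputs)
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    prepared = prepare_pt2e(captured, quantizer)

    # Assumption: one calibration pass on the sample inputs (the repro skips this).
    prepared(*sample_inputs)

    # Keep the weight in float with an explicit Quant + DeQuant pair
    # instead of folding it into an i8 parameter.
    quantized = convert_pt2e(prepared, fold_quantize=False)
    exported = torch.export.export(quantized, sample_inputs)
    return stablehlo.exported_program_to_stablehlo(exported).get_stablehlo_text()
```

With the model from the repro script, this would be called as `quantize_and_export_unfolded(TestModel().eval(), sample_inputs)`. The trade-off is that the weight stays in float in the exported program.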

Nullkooland commented 6 months ago

@lsy323 Thanks for your reply!

Will the QParam + DeQuant QDQ case be supported in the future, so that the exported StableHLO looks like the following?

module @ExampleQDQModel {
    func.func @main(
        %weight_0_q: tensor<32x3x3x3x!quant.uniform<i8:f32, 1.000000e+00>>,
        // Folded QParams.
        %input: tensor<1x3x64x64xf32>
    ) {
        %input_q = stablehlo.uniform_quantize %input : (tensor<1x3x64x64xf32>) -> tensor<1x3x64x64x!quant.uniform<i8:f32, 1.000000e+00>>
        %input_dq = stablehlo.uniform_dequantize %input_q : (tensor<1x3x64x64x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3x64x64xf32>
        %weight_0_dq = stablehlo.uniform_dequantize %weight_0_q : (tensor<32x3x3x3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<32x3x3x3xf32>
        %conv_0 = stablehlo.convolution(%input_dq, %weight_0_dq) {...} : (tensor<1x3x64x64xf32>, tensor<32x3x3x3xf32>) -> tensor<1x32x64x64xf32>
        %conv_0_q = stablehlo.uniform_quantize %conv_0 : (tensor<1x32x64x64xf32>) -> tensor<1x32x64x64x!quant.uniform<i8:f32, 1.000000e+00>>
        %conv_0_dq = stablehlo.uniform_dequantize %conv_0_q : (tensor<1x32x64x64x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x32x64x64xf32>
        ...
    }
}

This way, the quantized params could be exported directly, reducing the size of the exported artifacts.
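
As a rough illustration of the size argument, the arithmetic for the single conv_0 weight in this model can be sketched as below. It assumes 4 bytes per float32 element and 1 byte per int8 element, and it ignores the small overhead of storing the scale and zero point.

```python
import math

weight_shape = (32, 3, 3, 3)  # conv_0 weight shape from the HLO dump
num_elements = math.prod(weight_shape)

bytes_f32 = num_elements * 4  # float32: 4 bytes per element
bytes_i8 = num_elements * 1   # int8: 1 byte per element

print(f"{num_elements} elements: {bytes_f32} B as f32, {bytes_i8} B as i8")
```

For this toy model the saving is a few kilobytes, but the 4x ratio applies to every quantized weight in a larger network.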

lsy323 commented 5 months ago

Hi @Nullkooland, just FYI: STABLEHLO_BYTECODE_FROM_PRETTYPRINT=1 is no longer needed.

lsy323 commented 5 months ago

As for whether the QParam + DeQuant case will be supported, I'm not sure. I'll leave it to a StableHLO team member to provide input. cc @sdasgup3

sdasgup3 commented 4 months ago

"Will the QParam + DeQuant QDQ case be supported in the future?"

Yes, we have plans to achieve this outcome. There may be different paths to that goal, such as pattern matching at the ATen level versus the StableHLO level, which we are still exploring.