pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

int8 StableHLO export #8373

Open Wheest opened 1 week ago

Wheest commented 1 week ago

🐛 Bug

I'm trying to generate an int8-quantised PyTorch model (both weights and activations in int8) and export it to StableHLO via torch-xla's exported_program_to_stablehlo.

Right now I don't have a strong preference for how the model is quantised, as long as I end up with a valid graph with int8 weights and activations (presumably with i32 accumulation types).
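To make the target concrete, here is a minimal pure-Python sketch of what "int8 weights and activations with i32 accumulation" means numerically. It assumes a symmetric per-tensor scheme (zero point 0) with power-of-two scales picked only to keep the arithmetic exact; the helper names quantize_int8 and int8_dot are hypothetical and are not part of PyTorch or torch-xla.

```python
# Illustrative only: symmetric per-tensor int8 quantisation with a wide accumulator.
# Helper names and scale values are assumptions made for this sketch.

def quantize_int8(values, scale):
    """Map floats to the int8 range [-128, 127] using a symmetric scheme."""
    return [max(-128, min(127, round(v / scale))) for v in values]

def int8_dot(q_x, q_w):
    """Dot product of two int8 vectors.

    Each product can reach about 2**14 and the running sum grows further,
    so the accumulator must be wider than int8 (int32 in practice).
    """
    acc = 0
    for a, b in zip(q_x, q_w):
        acc += a * b
    # Check the result still fits in a signed 32-bit accumulator.
    assert -2**31 <= acc < 2**31
    return acc

x = [0.5, -1.0, 0.25]   # activations
w = [1.0, 0.5, -2.0]    # weights
x_scale = 1.0 / 64      # hypothetical activation scale
w_scale = 1.0 / 32      # hypothetical weight scale

q_x = quantize_int8(x, x_scale)
q_w = quantize_int8(w, w_scale)
acc = int8_dot(q_x, q_w)        # integer accumulator, outside int8 range
y = acc * x_scale * w_scale     # rescale back to a float result
```

With these values the accumulator falls well outside the int8 range, which is why a wider accumulation type is expected in the exported graph.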

There are several ways to quantise in PyTorch, each with its own caveats and issues. The furthest I've been able to get is the reproducible script below, which raises the following error:

  File "/app/examples/generate_weenet_mlp_int8.py", line 70, in <module>
    stablehlo_program = exported_program_to_stablehlo(exported)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/.local/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 618, in exported_program_to_stablehlo
    bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/.local/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 370, in _exported_program_to_stablehlo_bundle
    raise RuntimeError(message)
RuntimeError: This model contains ops not capturable by Pytorch/XLA: aten::_fused_moving_avg_obs_fq_helper

To Reproduce

import os
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx
from torch.utils.data import DataLoader, TensorDataset
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

# Ensure CPU-only execution for torch_xla and disable CUDA
os.environ["XLA_USE_BF16"] = "0"
os.environ["XLA_USE_CUDA"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Define a simple neural network model
class WeeNetMLP(nn.Module):
    def __init__(self):
        super(WeeNetMLP, self).__init__()
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(4 * 6, 32)
        self.relu1 = nn.ReLU()
        self.dense2 = nn.Linear(32, 16)
        self.relu2 = nn.ReLU()
        self.dense3 = nn.Linear(16, 8)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.relu1(x)
        x = self.dense2(x)
        x = self.relu2(x)
        x = self.dense3(x)
        x = self.relu3(x)
        return x

# Initialize the model and set to evaluation mode
model = WeeNetMLP().eval()

# Configure fake quantization using QAT configuration
qconfig_mapping = QConfigMapping().set_global(get_default_qat_qconfig("fbgemm"))

# Define example inputs and input shape
example_inputs = (torch.randn(1, 4, 6),)
input_shape = (1, 4, 6)

# Prepare the model for fake quantization
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs=example_inputs)

# Generate random data for calibration
calibration_data = torch.randn(100, *input_shape)

# Create a dataset and data loader for calibration
calibration_dataset = TensorDataset(calibration_data)
calibration_data_loader = DataLoader(calibration_dataset, batch_size=10)

# Calibrate the model with the calibration data
for data in calibration_data_loader:
    prepared_model(data[0])  # Run the prepared model on each batch of calibration data

# After calibration, disable the observers so the quantisation parameters stop updating
prepared_model.apply(torch.ao.quantization.disable_observer)

# Export the prepared model to StableHLO format
exported = export(prepared_model, example_inputs)
stablehlo_program = exported_program_to_stablehlo(exported)

Expected behavior

I would expect this to produce a StableHLO graph with int8 tensors in it.

If this can be achieved with a different quantisation method in PyTorch, that also works. The issue here seems to be around this aten op.
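For context, the op named in the error comes from the fused observer and fake-quantise module that the default QAT qconfig inserts for activations. Fake quantisation rounds and clamps values as if they were int8 but returns floats, so a graph built this way carries no int8 tensors until it is converted. A minimal pure-Python sketch of that round trip, assuming a per-tensor affine scheme with hypothetical scale and zero-point values (the function fake_quantize here is illustrative, not the PyTorch implementation):

```python
# Illustrative only: quantise-dequantise ("fake quantise") a float value.
# The scale and zero point below are assumptions for this sketch.

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Round x onto the int8 grid, clamp it, then map it back to float."""
    q = round(x / scale) + zero_point      # quantise to an integer level
    q = max(qmin, min(qmax, q))            # clamp to the int8 range
    return (q - zero_point) * scale        # dequantise: the result is a float

scale, zero_point = 0.125, 0
inputs = [0.3, -1.0, 20.0]
outputs = [fake_quantize(v, scale, zero_point) for v in inputs]
# 0.3 snaps to the nearest grid point, -1.0 is already on the grid,
# and 20.0 saturates at qmax * scale.
```

Every output is still a Python float, which mirrors why a model that has only been prepared, and not converted, does not yield integer tensors on export.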

Environment

JackCaoG commented 1 week ago

@lsy323 is out but he can take a look when he is back.

miladm commented 1 week ago

For more context: we are looking to expand torch_ao support in the near future. We appreciate you filing the bug and surfacing the use cases and issues you observed. @lsy323 will help drive this issue, as mentioned earlier.

Wheest commented 16 hours ago

Thanks! I was also looking at using torch_ao for a larger model, but ran into https://github.com/pytorch/pytorch/issues/140943, where the model didn't get past torch.export.