google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0
378 stars 51 forks source link

Problems with quantization #264

Open spacycoder opened 2 months ago

spacycoder commented 2 months ago

Description of the bug:

I am struggeling to convert a model when "is_dynamic" is set to False. This should reproduce the issue:

import torch.nn as nn
import torch
import ai_edge_torch

from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch._export import capture_pre_autograd_graph

from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig
import math
import torch.nn.functional as F
import os
os.environ["PJRT_DEVICE"]="CPU"

class DummyBlock(nn.Module):

    def __init__(self, n: int = 1, m: int = 2) -> None:
        super().__init__()
        self.m = m
        self.n = n

    def forward(self, x, x_dec, kernel):
        batches, channels, windows, bins = x_dec.shape
        n_out = channels // 3
        parts = torch.stack([x_dec[:, i * n_out : (i + 1) * n_out] for i in range(3)], dim=-1)
        h_real = (parts[:, :, :, :, 0] * 1) + (parts[:, :, :, :, 1] * -0.5) + (parts[:, :, :, :, 2] * -0.5)
        h_imag = (parts[:, :, :, :, 1] * math.sqrt(3) / 2) + (parts[:, :, :, :, 2] * -math.sqrt(3) / 2)

        batches, channels, windows, bins = x.shape

        windows = windows - self.m
        bins = bins - (2 * self.n)

        x_hat = torch.zeros((batches, channels, windows, bins), dtype=x.dtype, device=x.device)
        for t in range(self.m + 1):
            for f in range(self.n * 2 + 1):
                x_hat += x[:, :, t : t + windows, f : f + bins] * kernel[:, t, f].unsqueeze(1)

        return x_hat / 2, h_real, h_imag

def _main():
    torch_model = DummyBlock()
    batch_size = 1
    windows = 4
    bins = 2

    n = torch_model.n
    m = torch_model.m
    n_out = (m + 1) * (2 * n + 1)
    channels = 3 * n_out

    kernel = torch.randn((1, m + 1, 2 * n + 1), dtype=torch.float)
    x_dec = torch.randn(batch_size, channels, windows, bins).float()
    x = torch.randn(batch_size, channels, windows, bins).float()

    sample_args = (x, x_dec, kernel)
    torch_model(*sample_args)
    torch_model.eval()

    pt2e_quantizer = PT2EQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False)
    )

    pt2e_torch_model = capture_pre_autograd_graph(torch_model, sample_args)
    pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)
    pt2e_torch_model(*sample_args)
    pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

    pt2e_drq_model = ai_edge_torch.convert(
        pt2e_torch_model, sample_args, quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer)
    )

    pt2e_drq_model.export("dummy_block.tflite")

if __name__ == "__main__":
    _main()

I get the following error:


  File "../python3.11/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../python3.11/site-packages/torch/fx/graph_module.py", line 316, in __call__
    raise e
  File "../python3.11/site-packages/torch/fx/graph_module.py", line 303, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.8", line 23, in forward
  File "../python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "../python3.11/site-packages/torch/ao/quantization/observer.py", line 1230, in forward
    self.reset_histogram(x, min_val, max_val)
  File "../python3.11/site-packages/torch/ao/quantization/observer.py", line 1203, in reset_histogram
    torch.histc(
RuntimeError: torch.histogram: input tensor and hist tensor should have the same dtype, but got input long int and hist float

How do I resolve this?

Actual vs expected behavior:

No response

Any other information you'd like to share?

No response

pkgoogle commented 2 months ago

I was able to replicate this from the latest stable release