I am struggling to convert a model when "is_dynamic" is set to False. The following script should reproduce the issue:
import torch.nn as nn
import torch
import ai_edge_torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch._export import capture_pre_autograd_graph
from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig
import math
import torch.nn.functional as F
import os

os.environ["PJRT_DEVICE"] = "CPU"


class DummyBlock(nn.Module):
    def __init__(self, n: int = 1, m: int = 2) -> None:
        super().__init__()
        self.m = m
        self.n = n

    def forward(self, x, x_dec, kernel):
        batches, channels, windows, bins = x_dec.shape
        n_out = channels // 3
        # Split the channel axis into three groups and combine them.
        parts = torch.stack([x_dec[:, i * n_out : (i + 1) * n_out] for i in range(3)], dim=-1)
        h_real = (parts[:, :, :, :, 0] * 1) + (parts[:, :, :, :, 1] * -0.5) + (parts[:, :, :, :, 2] * -0.5)
        h_imag = (parts[:, :, :, :, 1] * math.sqrt(3) / 2) + (parts[:, :, :, :, 2] * -math.sqrt(3) / 2)
        batches, channels, windows, bins = x.shape
        windows = windows - self.m
        bins = bins - (2 * self.n)
        # Sliding-window weighted sum of x with the (m+1) x (2n+1) kernel.
        x_hat = torch.zeros((batches, channels, windows, bins), dtype=x.dtype, device=x.device)
        for t in range(self.m + 1):
            for f in range(self.n * 2 + 1):
                x_hat += x[:, :, t : t + windows, f : f + bins] * kernel[:, t, f].unsqueeze(1)
        return x_hat / 2, h_real, h_imag


def _main():
    torch_model = DummyBlock()
    batch_size = 1
    windows = 4
    bins = 2
    n = torch_model.n
    m = torch_model.m
    n_out = (m + 1) * (2 * n + 1)
    channels = 3 * n_out
    kernel = torch.randn((1, m + 1, 2 * n + 1), dtype=torch.float)
    x_dec = torch.randn(batch_size, channels, windows, bins).float()
    x = torch.randn(batch_size, channels, windows, bins).float()
    sample_args = (x, x_dec, kernel)
    torch_model(*sample_args)
    torch_model.eval()
    # Static (is_dynamic=False), per-channel symmetric quantization config.
    pt2e_quantizer = PT2EQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False)
    )
    pt2e_torch_model = capture_pre_autograd_graph(torch_model, sample_args)
    pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)
    pt2e_torch_model(*sample_args)  # calibration forward pass
    pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)
    pt2e_drq_model = ai_edge_torch.convert(
        pt2e_torch_model, sample_args, quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer)
    )
    pt2e_drq_model.export("dummy_block.tflite")


if __name__ == "__main__":
    _main()
I get the following error:
File "../python3.11/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "../python3.11/site-packages/torch/fx/graph_module.py", line 316, in __call__
raise e
File "../python3.11/site-packages/torch/fx/graph_module.py", line 303, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "../python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "../python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<eval_with_key>.8", line 23, in forward
File "../python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "../python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "../python3.11/site-packages/torch/ao/quantization/observer.py", line 1230, in forward
self.reset_histogram(x, min_val, max_val)
File "../python3.11/site-packages/torch/ao/quantization/observer.py", line 1203, in reset_histogram
torch.histc(
RuntimeError: torch.histogram: input tensor and hist tensor should have the same dtype, but got input long int and hist float
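If it helps with triage: the failure happens inside HistogramObserver.reset_histogram, which hands the observed tensor to torch.histc. The dtype mismatch is reproducible in isolation; my guess (an assumption on my part) is that some intermediate tensor in the captured graph is int64 and ends up with an activation observer attached:

import torch

# torch.histc's internal histogram buffer is float, so an integral ("long")
# input trips the same dtype check as in the traceback above.
x = torch.arange(6)  # dtype=torch.int64
try:
    torch.histc(x, bins=256, min=0, max=5)
except RuntimeError as e:
    print(e)  # on my install: "... got input long int and hist float"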
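Possibly a useful data point: the same script converts without errors for me when I pass is_dynamic=True, so the problem seems specific to the static-quantization path, where activation observers (HistogramObserver, per the traceback) run during the calibration forward pass. The only line I change between the two runs is the config:

# works for me (dynamic quantization; activations are not calibrated
# through a HistogramObserver):
get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)

# fails during calibration with the torch.histc error above:
get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False)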
How do I resolve this?