Hi, I get the error `'tfl.strided_slice' op failed to verify that operand 0 is at most 5-D` when I try to convert the following model with `ai_edge_torch.convert`:
```python
import torch.nn as nn
import torch
import math
import ai_edge_torch
import os

os.environ["PJRT_DEVICE"] = "CPU"


class DummyBlock(nn.Module):
    def forward(self, x: torch.Tensor):
        n_out = x.shape[1] // 3
        # Split dim 1 into three equal groups, stack them on a new last axis,
        # then add one more trailing axis: for a 4-D input this gives a 6-D tensor.
        parts = torch.stack([x[:, i * n_out : (i + 1) * n_out] for i in range(3)], dim=-1).unsqueeze(-1)
        h_real = (parts[..., 0, :] * 1) + (parts[..., 1, :] * -0.5) + (parts[..., 2, :] * -0.5)
        h_imag = (parts[..., 1, :] * math.sqrt(3) / 2) + (parts[..., 2, :] * -math.sqrt(3) / 2)
        return h_real, h_imag


def _main():
    torch_model = DummyBlock()
    x_dec = torch.randn(1, 9, 10, 2).float()
    sample_args = (x_dec,)
    torch_model(*sample_args)  # eager run works fine
    edge_model = ai_edge_torch.convert(torch_model.eval(), sample_args)
    edge_model.export("dummy_block.tflite")


if __name__ == "__main__":
    _main()
```
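For the sample input of shape `(1, 9, 10, 2)`, the intermediate shapes can be traced in eager PyTorch. This small sketch just repeats the `stack`/`unsqueeze` step from `forward` outside the module; it shows that `parts` is 6-D, so the `parts[..., k, :]` indexing presumably lowers to a strided slice on a 6-D operand, which is what the error says `tfl.strided_slice` rejects.

```python
import torch

x = torch.randn(1, 9, 10, 2)
n_out = x.shape[1] // 3  # 3

# Same expression as in DummyBlock.forward, split into two steps.
stacked = torch.stack([x[:, i * n_out : (i + 1) * n_out] for i in range(3)], dim=-1)
parts = stacked.unsqueeze(-1)

print(tuple(stacked.shape))           # (1, 3, 10, 2, 3)
print(tuple(parts.shape))             # (1, 3, 10, 2, 3, 1)
print(parts.dim())                    # 6
print(tuple(parts[..., 0, :].shape))  # (1, 3, 10, 2, 1)
```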
This is the full error:
```
tensorflow.lite.python.convert_phase.ConverterError: <unknown>:0: error: loc(callsite(callsite(callsite("torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_30"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_46"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): 'tfl.strided_slice' op failed to verify that operand 0 is at most 5-D
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
I0000 00:00:1727425282.529399 653780 cpu_client.cc:470] TfrtCpuClient destroyed.
```
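A possible workaround, sketched under the assumption that the 6-D `parts` tensor is the only problem: slice the three channel groups directly and unsqueeze each one, so every intermediate stays at most 5-D. `DummyBlockNoStack` and `reference` are hypothetical names introduced here; `reference` is just the original `forward` body copied for an eager-mode equivalence check. I have not verified that `ai_edge_torch` converts this version.

```python
import math

import torch
import torch.nn as nn


class DummyBlockNoStack(nn.Module):
    """Hypothetical rewrite of DummyBlock that avoids the 6-D stacked tensor."""

    def forward(self, x: torch.Tensor):
        n_out = x.shape[1] // 3
        # Each group is (N, n_out, H, W) -> (N, n_out, H, W, 1): the same shape as
        # parts[..., k, :] in the original, but built from a 4-D input, so it stays 5-D.
        p0, p1, p2 = (x[:, i * n_out : (i + 1) * n_out].unsqueeze(-1) for i in range(3))
        h_real = (p0 * 1) + (p1 * -0.5) + (p2 * -0.5)
        h_imag = (p1 * math.sqrt(3) / 2) + (p2 * -math.sqrt(3) / 2)
        return h_real, h_imag


def reference(x: torch.Tensor):
    # Original DummyBlock.forward, copied for comparison.
    n_out = x.shape[1] // 3
    parts = torch.stack([x[:, i * n_out : (i + 1) * n_out] for i in range(3)], dim=-1).unsqueeze(-1)
    h_real = (parts[..., 0, :] * 1) + (parts[..., 1, :] * -0.5) + (parts[..., 2, :] * -0.5)
    h_imag = (parts[..., 1, :] * math.sqrt(3) / 2) + (parts[..., 2, :] * -math.sqrt(3) / 2)
    return h_real, h_imag


x_dec = torch.randn(1, 9, 10, 2)
new_real, new_imag = DummyBlockNoStack()(x_dec)
ref_real, ref_imag = reference(x_dec)
print(tuple(new_real.shape), torch.allclose(new_real, ref_real), torch.allclose(new_imag, ref_imag))
```

If it converts, it could be passed to `ai_edge_torch.convert(DummyBlockNoStack().eval(), (x_dec,))` in place of the original model. Even so, the original 6-D indexing pattern arguably should either convert or fail with a clearer message.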
Result from running find_culprits:
Actual vs expected behavior:
No response
Any other information you'd like to share?
No response