google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

op failed to verify that operand 0 is at most 5-D #263

Closed: spacycoder closed this issue 1 month ago.

spacycoder commented 1 month ago

Description of the bug:

Hi, I get the error "op failed to verify that operand 0 is at most 5-D" when I try to convert the following model:

import os

# Use the CPU PJRT device; set it before importing ai_edge_torch to be safe.
os.environ["PJRT_DEVICE"] = "CPU"

import math

import torch
import torch.nn as nn
import ai_edge_torch

class DummyBlock(nn.Module):

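    def forward(self, x: torch.Tensor):
        # Split the channel dim into three equal chunks (3 channels each for the sample input).
        n_out = x.shape[1] // 3
        # For the sample input: stack -> (1, 3, 10, 2, 3); unsqueeze(-1) -> (1, 3, 10, 2, 3, 1), a 6-D tensor.
        parts = torch.stack([x[:, i * n_out : (i + 1) * n_out] for i in range(3)], dim=-1).unsqueeze(-1)
        # parts[..., k, :] selects along dim 4, so each of these indexing ops consumes the 6-D tensor.
        h_real = (parts[..., 0, :] * 1) + (parts[..., 1, :] * -0.5) + (parts[..., 2, :] * -0.5)
        h_imag = (parts[..., 1, :] * math.sqrt(3) / 2) + (parts[..., 2, :] * -math.sqrt(3) / 2)
        return h_real, h_imag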

def _main():
    torch_model = DummyBlock()
    x_dec = torch.randn(1, 9, 10, 2).float()
    sample_args = (x_dec,)
    torch_model(*sample_args)  # eager forward pass as a sanity check before converting

    edge_model = ai_edge_torch.convert(torch_model.eval(), sample_args)
    edge_model.export("dummy_block.tflite")

if __name__ == "__main__":
    _main()

This is the full error:

tensorflow.lite.python.convert_phase.ConverterError: <unknown>:0: error: loc(callsite(callsite(callsite("torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_30"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_46"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): 'tfl.strided_slice' op failed to verify that operand 0 is at most 5-D
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
I0000 00:00:1727425282.529399  653780 cpu_client.cc:470] TfrtCpuClient destroyed.
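The error names tfl.strided_slice and its 5-D limit. Tracing the tensor shapes eagerly for the sample input shows where the sixth dimension comes from. This is a minimal sketch that assumes only PyTorch is installed (it does not call ai_edge_torch); the variable names are illustrative:

```python
import torch

x = torch.randn(1, 9, 10, 2)
n_out = x.shape[1] // 3  # 3 channels per chunk for this input

chunks = [x[:, i * n_out : (i + 1) * n_out] for i in range(3)]
stacked = torch.stack(chunks, dim=-1)  # 5-D: (1, 3, 10, 2, 3)
parts = stacked.unsqueeze(-1)          # 6-D: (1, 3, 10, 2, 3, 1)

# With the trailing size-1 dim, parts[..., 0, :] indexes dim 4 (the stacked
# dim of size 3), so every such selection reads from the 6-D tensor.
first = parts[..., 0, :]

print(tuple(stacked.shape))
print(tuple(parts.shape))
print(tuple(first.shape))
```

The last shape, (1, 3, 10, 2, 1), matches the select node that find_culprits reports below.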

Result from running find_culprits:

import torch
from torch import device
import ai_edge_torch

class CulpritGraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[1, 3, 10, 2, 3, 1]"):
        # File: .../test.py:15 in forward, code: h_real = (parts[..., 0, :] * 1) + (parts[..., 1, :] * -0.5) + (parts[..., 2, :] * -0.5)
        select: "f32[1, 3, 10, 2, 1]" = torch.ops.aten.select.int(arg0_1, 4, 2);  arg0_1 = None
        return (select,)

_args = (
    torch.randn((1, 3, 10, 2, 3, 1,), dtype=torch.float32),
)

_edge_model = ai_edge_torch.convert(CulpritGraphModule().eval(), _args)

Actual vs expected behavior:

No response

Any other information you'd like to share?

No response

spacycoder commented 1 month ago

Fixed; I just had to remove the unsqueeze().
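For the dummy model above, deleting .unsqueeze(-1) on its own also shifts what parts[..., k, :] indexes: on the 5-D stacked tensor it selects along dim 3 (size 2), and parts[..., 2, :] raises an IndexError. Below is a minimal sketch of a variant that keeps every intermediate at 5-D or below, assuming the intent is to select along the stacked dim and keep the original (1, 3, 10, 2, 1) output shapes. DummyBlock5D is a hypothetical name, and whether the converter accepts this exact graph is untested here:

```python
import math

import torch
import torch.nn as nn


class DummyBlock(nn.Module):
    """Original reproducer: builds a 6-D tensor via unsqueeze(-1)."""

    def forward(self, x: torch.Tensor):
        n_out = x.shape[1] // 3
        parts = torch.stack([x[:, i * n_out : (i + 1) * n_out] for i in range(3)], dim=-1).unsqueeze(-1)
        h_real = (parts[..., 0, :] * 1) + (parts[..., 1, :] * -0.5) + (parts[..., 2, :] * -0.5)
        h_imag = (parts[..., 1, :] * math.sqrt(3) / 2) + (parts[..., 2, :] * -math.sqrt(3) / 2)
        return h_real, h_imag


class DummyBlock5D(nn.Module):
    """Hypothetical rewrite: index the 5-D stacked tensor, unsqueeze only the outputs."""

    def forward(self, x: torch.Tensor):
        n_out = x.shape[1] // 3
        parts = torch.stack([x[:, i * n_out : (i + 1) * n_out] for i in range(3)], dim=-1)  # 5-D
        # parts[..., k] selects along the last (stacked) dim, so no op reads a 6-D tensor.
        h_real = (parts[..., 0] * 1) + (parts[..., 1] * -0.5) + (parts[..., 2] * -0.5)
        h_imag = (parts[..., 1] * math.sqrt(3) / 2) + (parts[..., 2] * -math.sqrt(3) / 2)
        # Restore the trailing size-1 dim so the outputs match the original shapes.
        return h_real.unsqueeze(-1), h_imag.unsqueeze(-1)


x = torch.randn(1, 9, 10, 2)
ref = DummyBlock()(x)
new = DummyBlock5D()(x)
for a, b in zip(ref, new):
    print(tuple(b.shape), torch.allclose(a, b))
```

Moving the unsqueeze to the outputs keeps the 6-D tensor out of the graph entirely while leaving the results unchanged; if the trailing size-1 dim is not needed downstream, the final unsqueeze calls can be dropped as well.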