google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Converting Torch modules that use the max function #81

Open hbellafkir opened 3 months ago

hbellafkir commented 3 months ago

Description of the bug:

Given a module that calculates the maximum of a tensor, the convert method fails to convert the model.

import ai_edge_torch
import torch

class Max(torch.nn.Module):
    def forward(self, x):
        return x.max()

sample_input = (torch.randn(8),)
edge_model = ai_edge_torch.convert(Max().eval(), sample_input)

Actual vs expected behavior:

Actual behavior:

ConverterError: Could not translate MLIR to FlatBuffer.<unknown>:0: error: loc(callsite(callsite(callsite("__main__.Max;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_2449"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_2455"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): 'vhlo.reduce_v1' op is not part of the vhlo support yet.
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
<unknown>:0: error: failed while converting: 'main': 

Expected behavior: The method successfully converts the model
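As a stopgap while vhlo.reduce_v1 is unsupported, one idea (untested, just a sketch) is to express the global max through a pooling op, which may lower to a different TFLite kernel than a generic reduce. The reshape and kernel_size below assume a fixed-size 1-D input:

import ai_edge_torch
import torch
import torch.nn.functional as F

class MaxViaPool(torch.nn.Module):
    # Hypothetical workaround: compute the global max with max_pool1d
    # instead of x.max(), so a generic reduce op is (hopefully) not emitted.
    # Assumes a fixed-length 1-D input of 8 elements.
    def forward(self, x):
        return F.max_pool1d(x.reshape(1, 1, -1), kernel_size=8).reshape(())

sample_input = (torch.randn(8),)
edge_model = ai_edge_torch.convert(MaxViaPool().eval(), sample_input)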

Any other information you'd like to share?

No response

pkgoogle commented 2 months ago

I was able to replicate this exactly as above on internal Debian.

niemiaszek commented 1 month ago

Is there any update on support for vhlo.reduce_v1? I'm trying to convert a small variation of MobileNetV3 for audio, EfficientAT. The model can be obtained using models.mn.model.get_model() from the referenced repository, with defaults. It goes through torch.export successfully for me, but fails at TFLite conversion. I could export MobileNetV3 from the TorchVision repository without any problem; this EfficientAT implementation is derived from it, but uses models.mn.block_types.ConcurrentSEBlock with max aggregation.
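For reference, roughly what I'm running (the input shape here is a guess on my part; get_model() is called with its defaults, as described above):

import torch
import ai_edge_torch
from models.mn.model import get_model  # EfficientAT repository

model = get_model().eval()  # defaults
sample_input = (torch.randn(1, 1, 128, 1000),)  # assumed mel-spectrogram-shaped input
exported = torch.export.export(model, sample_input)  # this step succeeds
edge_model = ai_edge_torch.convert(model, sample_input)  # fails with the error below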

ConverterError: Could not translate MLIR to FlatBuffer.<unknown>:0: error: loc(callsite(callsite(callsite("models.mn.model.MN/models.mn.block_types.InvertedResidual_4/torch.nn.modules.container.Sequential_block/models.mn.block_types.ConcurrentSEBlock_2;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_29013"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_29553"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): 'vhlo.reduce_v1' op is not part of the vhlo support yet.
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
<unknown>:0: note: loc(callsite(callsite(callsite("models.mn.model.MN/models.mn.block_types.InvertedResidual_4/torch.nn.modules.container.Sequential_block/models.mn.block_types.ConcurrentSEBlock_2;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_29013"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_29553"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): see current operation: 
%128 = "vhlo.reduce_v1"(%127, %99) <{dimensions = #vhlo.tensor_v1<dense<0> : tensor<1xi64>>}> ({
^bb0(%arg15: tensor<f32>, %arg16: tensor<f32>):
  %280 = "tfl.maximum"(%arg15, %arg16) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  "vhlo.return_v1"(%280) : (tensor<f32>) -> ()
}) : (tensor<1x1x72x8x8xf32>, tensor<f32>) -> tensor<1x72x8x8xf32>

Sorry I can't offer more concrete help with debugging this, but I've only just started digging into your work and learning its innards. I'm very excited about this set of tools, thank you :)

pkgoogle commented 3 weeks ago

There is an open PR for this: https://github.com/tensorflow/tensorflow/pull/73635.