google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Replace SELECT with SELECT_V2 #93

Open pjpratik opened 1 month ago

pjpratik commented 1 month ago

Description of the bug:

The converter emits the SELECT op in the TFLite model. It needs to be replaced with SELECT_V2, because TFLite Micro supports SELECT_V2 but not SELECT. As a result, the generated TFLite model cannot be used with TFLite Micro. To reproduce:

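```python
import torch
import ai_edge_torch

# A single-parameter PReLU module; conversion lowers it to GREATER, MUL and SELECT ops.
m = torch.nn.PReLU()
# Sample inputs are passed as a tuple of positional arguments.
sample_args = (torch.randn(2),)
edge_model = ai_edge_torch.convert(m.eval(), sample_args)
edge_model.export("prelu.tflite")
```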

Then, visualizing prelu.tflite with Model Explorer shows that PReLU has been lowered to GREATER, MUL and SELECT ops.
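For this PReLU graph, swapping SELECT for SELECT_V2 should not change the results. As far as I know, SELECT follows the older `tf.where` rules (condition, x and y share a shape, or a 1-D condition matches the first dimension), while SELECT_V2 adds NumPy-style broadcasting, so the two agree whenever all operands already have the same shape, as they do here. The sketch below models both ops on plain 1-D Python lists only; `select_v1` and `select_v2` are hypothetical helpers for illustration, not TFLite APIs.

```python
def select_v1(cond, x, y):
    # Simplified SELECT (v1): cond, x and y must have the same length (1-D only).
    if not (len(cond) == len(x) == len(y)):
        raise ValueError("SELECT requires matching shapes")
    return [xi if c else yi for c, xi, yi in zip(cond, x, y)]


def _broadcast(seq, n):
    # Repeat a length-1 list n times; otherwise return it unchanged.
    return seq * n if len(seq) == 1 else seq


def select_v2(cond, x, y):
    # Simplified SELECT_V2: broadcasts length-1 operands (1-D only).
    n = max(len(cond), len(x), len(y))
    for s in (cond, x, y):
        if len(s) not in (1, n):
            raise ValueError("shapes are not broadcastable")
    cond, x, y = (_broadcast(s, n) for s in (cond, x, y))
    return [xi if c else yi for c, xi, yi in zip(cond, x, y)]


# The PReLU pattern: all three operands have the same shape as x.
x = [-2.0, 0.5]
alpha = 0.25
cond = [v > 0.0 for v in x]
scaled = [alpha * v for v in x]
result_v1 = select_v1(cond, x, scaled)
result_v2 = select_v2(cond, x, scaled)
print(result_v1 == result_v2)  # True
```

Only the broadcasting case distinguishes the two; same-shape inputs take the identical element-wise path.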

Actual vs expected behavior:

Actual: registering the model's ops with TFLite Micro fails to compile:

```
error: 'class tflite::MicroMutableOpResolver<16>' has no member named 'AddSelect'; did you mean 'AddSelectV2'?
   73 |   micro_op_resolver.AddSelect();
      |                     ^~~~~~~~~
      |                     AddSelectV2
```

Expected: the converted model uses only ops supported by TFLite Micro, or PReLU is emitted directly as the PRELU builtin op, since TFLite already provides it.
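A compile failure like the one above can be anticipated by comparing the model's op names against what the Micro op resolver can register. This is a minimal sketch, assuming the op names have already been extracted from the .tflite file (for example, read off the Model Explorer graph). `MICRO_RESOLVER_OPS` is an illustrative, incomplete subset chosen for this example, and `unsupported_ops` is a hypothetical helper.

```python
# Illustrative subset only: ops assumed to have an Add<Op>() registration on
# tflite::MicroMutableOpResolver. This is not the full TFLite Micro op list.
MICRO_RESOLVER_OPS = {"GREATER", "MUL", "PRELU", "SELECT_V2"}


def unsupported_ops(model_ops, supported=MICRO_RESOLVER_OPS):
    """Return ops in model_ops that are not in `supported`, in first-seen order."""
    missing = []
    for op in model_ops:
        if op not in supported and op not in missing:
            missing.append(op)
    return missing


# Op names of the converted PReLU graph (extraction step not shown).
print(unsupported_ops(["GREATER", "MUL", "SELECT"]))  # ['SELECT']
```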


pkgoogle commented 1 month ago

Hey @pjpratik, that seems reasonable; I'll need a microcontroller device to test it. I'm also wondering whether we'll need a separate setting for Micro conversions.
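One way a separate Micro setting could look is an opt-in target that renames ops only when Micro is requested. This toy sketch operates on op-name lists, not on a real converter or flatbuffer; `rewrite_for_target`, its `target` parameter and `MICRO_OP_REWRITES` are all hypothetical and are not part of ai-edge-torch. The rename is only meaningful because, in this graph, SELECT's operands already share a shape.

```python
# Hypothetical op renames applied only when targeting TFLite Micro.
MICRO_OP_REWRITES = {"SELECT": "SELECT_V2"}


def rewrite_for_target(model_ops, target="tflite"):
    """Rename ops for a hypothetical 'micro' target; other targets pass through unchanged."""
    if target != "micro":
        return list(model_ops)
    return [MICRO_OP_REWRITES.get(op, op) for op in model_ops]


ops = ["GREATER", "MUL", "SELECT"]
print(rewrite_for_target(ops))                  # ['GREATER', 'MUL', 'SELECT']
print(rewrite_for_target(ops, target="micro"))  # ['GREATER', 'MUL', 'SELECT_V2']
```

Gating the rename behind a target leaves default conversions unchanged; if SELECT_V2 turns out to be acceptable everywhere, the setting could be dropped and the rename applied unconditionally.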

pjpratik commented 1 month ago

Hi @pkgoogle, I was also wondering whether PReLU could be emitted directly as a single node instead of being broken down into GREATER, MUL and SELECT ops, since TFLite already has a PRELU builtin:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/builtin_ops.h#L84
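To show that the single builtin op and the three-op pattern compute the same thing, here is a plain-Python sketch, assuming the decomposition has the form `select(x > 0, x, alpha * x)` (the exact operand order in the converted graph is an assumption). `prelu_builtin` and `prelu_decomposed` are hypothetical helpers; `alpha = 0.25` matches the default initial weight of `torch.nn.PReLU()`.

```python
def prelu_builtin(x, alpha):
    # Reference PReLU, element-wise: max(0, x) + alpha * min(0, x).
    return [max(0.0, v) + alpha * min(0.0, v) for v in x]


def prelu_decomposed(x, alpha):
    # Assumed three-op pattern from the converted graph.
    greater = [v > 0.0 for v in x]   # GREATER(x, 0)
    scaled = [alpha * v for v in x]  # MUL(x, alpha)
    # SELECT(greater, x, scaled)
    return [v if g else s for g, v, s in zip(greater, x, scaled)]


x = [-2.0, -0.5, 0.0, 1.5]
alpha = 0.25  # torch.nn.PReLU() initializes its single weight to 0.25
print(prelu_decomposed(x, alpha))  # [-0.5, -0.125, 0.0, 1.5]
print(prelu_builtin(x, alpha))     # [-0.5, -0.125, 0.0, 1.5]
```

Emitting the builtin would replace three kernels (and the intermediate tensors between them) with one, and would sidestep the SELECT/SELECT_V2 question entirely.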

Thanks.

pkgoogle commented 1 month ago

@pjpratik I believe that would be preferable. I looked into it, and the decomposition appears to happen here: https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/convert/conversion_utils.py#L133

You may want to check whether PyTorch can export to StableHLO without the decomposition: https://github.com/pytorch/xla/blob/master/torch_xla/stablehlo.py#L321
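Decomposition passes are commonly driven by a per-op table, so exporting "without the decomposition" amounts to leaving an op out of that table (or explicitly keeping it), letting it survive as a single node that a later stage could map to the PRELU builtin. The toy sketch below illustrates only that idea; `DECOMPOSITIONS` and `lower` are hypothetical and do not reflect the actual code at the linked lines in ai-edge-torch or torch_xla.

```python
# Toy decomposition table: each entry expands one op into simpler ops.
# Op names are illustrative, not taken from ai-edge-torch or torch_xla.
DECOMPOSITIONS = {
    "prelu": ["greater", "mul", "select"],
}


def lower(ops, decompositions=DECOMPOSITIONS, keep=frozenset()):
    """Expand ops found in `decompositions`, except those listed in `keep`."""
    lowered = []
    for op in ops:
        if op in decompositions and op not in keep:
            lowered.extend(decompositions[op])
        else:
            lowered.append(op)
    return lowered


print(lower(["prelu"]))                  # ['greater', 'mul', 'select']
print(lower(["prelu"], keep={"prelu"}))  # ['prelu']
```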