tensorflow / flutter-tflite

get_signature_runner(...) feature #81

Open tonieichelkraut opened 1 year ago

tonieichelkraut commented 1 year ago

Hi, I have a model with multiple inputs and outputs. When I run the model in Python, I use the signature runner:

interpreter = tf.lite.Interpreter(model_path=tflite_file)
interpreter.allocate_tensors()
fn = interpreter.get_signature_runner('serving_default')

This nicely maps my inputs and outputs by name. flutter-tflite does not seem to have this feature (I am new to Flutter/Dart). When I use

interpreter.runForMultipleInputs(inputs, outputs)

I need to know the order of the input and output lists. When I run

interpreter.getOutputIndex('random_name')

I get

Invalid argument(s): Output error: random_name' is not a valid name for any output. Names of outputs and their indexes are {StatefulPartitionedCall:10: 0, StatefulPartitionedCall:3: 1, StatefulPartitionedCall:0: 2, StatefulPartitionedCall:2: 3, StatefulPartitionedCall:9: 4, StatefulPartitionedCall:8: 5, StatefulPartitionedCall:13: 6, StatefulPartitionedCall:12: 7, StatefulPartitionedCall:5: 8, StatefulPartitionedCall:1: 9, StatefulPartitionedCall:4: 10, StatefulPartitionedCall:11: 11, StatefulPartitionedCall:6: 12, StatefulPartitionedCall:7: 13}

However, the "StatefulPartitionedCall:X" names do not correspond to the output names in my Python code, and I don't want to work out the mapping by trial and error. Any suggestions on how I can solve this with the available feature set, or how I can access the signature runner from Dart?

Thank you!

gregorscholz commented 1 year ago

Try using interpreter.getOutputTensors() and interpreter.getInputTensors(). As the names suggest, they return a List of the Tensors the model produces and expects, respectively.

You could call interpreter.getInputTensors().forEach((tensor) => print(tensor)) to print every input tensor; the same works for the output tensors.

Maybe look into this example. Here the model takes a Tensor of the shape [1, 300, 300, 3] and outputs multiple tensors as you can see in line 142.

You can also get a single input or output tensor by just calling interpreter.getInputTensor(int index) or interpreter.getOutputTensor(int index).
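
One possible way to avoid guessing the order is to build the mapping once on the Python side and carry it over to Dart. The sketch below is pure Python and only does the bookkeeping: it parses the name-to-index map printed in the Dart error message above and joins it with a map from signature output names to tensor names. The signature names "boxes" and "scores" and their pairing with tensor names are hypothetical sample data. The comment about reading the real pairing from fn.get_output_details() is an assumption about the Python signature runner that should be checked against your TensorFlow version.

```python
def parse_output_indexes(error_message: str) -> dict[str, int]:
    """Extract the {tensor name: index} map printed in the Dart error message."""
    start = error_message.index("{")
    end = error_message.rindex("}")
    mapping = {}
    for entry in error_message[start + 1:end].split(", "):
        # Tensor names contain ':' themselves, so split on the last ': ' only.
        name, index = entry.rsplit(": ", 1)
        mapping[name] = int(index)
    return mapping


def signature_to_dart_index(
    signature_tensor_names: dict[str, str],
    dart_indexes: dict[str, int],
) -> dict[str, int]:
    """Map each signature output name to its index in the Dart output list."""
    return {
        signature_name: dart_indexes[tensor_name]
        for signature_name, tensor_name in signature_tensor_names.items()
    }


# Shortened copy of the error message from the issue.
message = (
    "Invalid argument(s): Output error: random_name' is not a valid name for "
    "any output. Names of outputs and their indexes are "
    "{StatefulPartitionedCall:10: 0, StatefulPartitionedCall:3: 1}"
)

# Hypothetical pairing of signature output names with tensor names.
# Assumption: in Python this could be read from the signature runner, e.g.
#   {k: d["name"] for k, d in fn.get_output_details().items()}
signature_tensor_names = {
    "boxes": "StatefulPartitionedCall:3",
    "scores": "StatefulPartitionedCall:10",
}

dart_indexes = parse_output_indexes(message)
print(signature_to_dart_index(signature_tensor_names, dart_indexes))
# prints {'boxes': 1, 'scores': 0}
```

The resulting dictionary could then be hard-coded or shipped as a small JSON asset next to the model, so the Dart code can look up output indexes by signature name when filling the outputs map for runForMultipleInputs.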