am15h / tflite_flutter_plugin

TensorFlow Lite Flutter Plugin
https://pub.dev/packages/tflite_flutter
Apache License 2.0

Feature Request: Enable Batch Processing #231

Open saurabhkumar8112 opened 1 year ago

saurabhkumar8112 commented 1 year ago

Right now there is no way to do batch processing: I am trying to run a classification model on a batch of images instead of a single image.

The standard way to run inference on one image is to convert the Image to a TensorImage and run the interpreter on its buffer:

TensorImage img = TensorImage.fromImage(image);
interpreter.run(img.buffer, output_buffer.getBuffer());

I have tried resizing the input tensor to [batch_size, H, W, C]:

_interpreter.resizeInputTensor(0, [batch_size, H, W, 3]);
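
After this resize the model still has a single input tensor, which now expects one value of shape [batch_size, H, W, 3], i.e. batch_size * H * W * 3 numbers. As a rough pure-Dart sketch (the helpers `shapeOf`, `elementCount` and `blankImage` are hypothetical and not part of tflite_flutter), the code below computes the shape of a nested List, which shows that a list of 8 H x W x 3 images is one [8, H, W, 3] value rather than 8 separate inputs:

```dart
/// Hypothetical helper (not part of tflite_flutter): returns the shape of a
/// nested List by following the first element at each level. Assumes the
/// nesting is rectangular, i.e. all siblings at a level have the same length.
List<int> shapeOf(Object value) {
  final shape = <int>[];
  dynamic current = value;
  while (current is List) {
    shape.add(current.length);
    if (current.isEmpty) break;
    current = current.first;
  }
  return shape;
}

/// Number of scalar values a tensor with the given [shape] holds.
int elementCount(List<int> shape) =>
    shape.fold(1, (product, dim) => product * dim);

/// Hypothetical stand-in for one preprocessed image: an h x w x c nested
/// List of zeros.
List<List<List<double>>> blankImage(int h, int w, int c) => List.generate(
    h, (_) => List.generate(w, (_) => List<double>.filled(c, 0.0)));
```

For example, `shapeOf(List.generate(8, (_) => blankImage(H, W, 3)))` is `[8, H, W, 3]`, the same shape tensor 0 was resized to. Whether the plugin accepts such a nested List directly as that single input has not been verified here.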

Now, if I want to run the interpreter on a batch of, say, 8 images, I can't pass it a List of images, because:

  1. A List of images can't be converted into a single TensorImage (the tflite_flutter_helper plugin doesn't support it).
  2. If I use interpreter.runForMultipleInputs(List, output_buffer) instead, the interpreter throws an exception. runForMultipleInputs pairs inputs[i] with input tensor i, so it loops over inputs.length (the batch size, e.g. 8), while getInputTensors() returns only one tensor, of shape [batch_size, H, W, C]. The List therefore never matches the input tensors, and inputTensors.elementAt(i) throws once i reaches 1. The loop in question, from the Interpreter class:
  var inputTensors = getInputTensors();

  for (int i = 0; i < inputs.length; i++) {
    var tensor = inputTensors.elementAt(i);
    final newShape = tensor.getInputShapeIfDifferent(inputs[i]);
    if (newShape != null) {
      resizeInputTensor(i, newShape);
    }
  }
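
One possible workaround, sketched under assumptions: instead of a List with one element per image, pass the whole batch as the single input, packed contiguously in [batch_size, H, W, C] order. The helpers `stackBatch` and `splitBatch` below are hypothetical (not part of tflite_flutter or tflite_flutter_helper) and assume a float32 model with NHWC input, an output of shape [batch_size, numClasses], and that each image has already been preprocessed into a flat Float32List of H * W * C values:

```dart
import 'dart:typed_data';

/// Hypothetical helper: packs [images] (each a flat Float32List of h * w * c
/// values in HWC order) into one contiguous buffer laid out as
/// [batch, h, w, c], matching a tensor resized to [images.length, h, w, c].
Float32List stackBatch(List<Float32List> images, int h, int w, int c) {
  final perImage = h * w * c;
  final batch = Float32List(images.length * perImage);
  for (var i = 0; i < images.length; i++) {
    if (images[i].length != perImage) {
      throw ArgumentError(
          'image $i has ${images[i].length} values, expected $perImage');
    }
    // Image i occupies the slice [i * perImage, (i + 1) * perImage).
    batch.setRange(i * perImage, (i + 1) * perImage, images[i]);
  }
  return batch;
}

/// Hypothetical helper: splits a flat output of shape [batch, numClasses]
/// back into one score vector (a copy) per image.
List<Float32List> splitBatch(Float32List output, int numClasses) {
  if (output.length % numClasses != 0) {
    throw ArgumentError(
        'output length ${output.length} is not a multiple of $numClasses');
  }
  final count = output.length ~/ numClasses;
  return List.generate(
      count, (i) => output.sublist(i * numClasses, (i + 1) * numClasses));
}
```

Assuming `run` accepts a raw ByteBuffer the same way it does for `img.buffer` above, the call would be along the lines of `interpreter.run(stackBatch(images, H, W, 3).buffer, output_buffer.getBuffer())`, with the output then read back as a Float32List and passed to `splitBatch`. A quantized (uint8) model would need a Uint8List instead; none of this has been verified against the plugin.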

I have looked on Stack Overflow and in the source code, and my understanding is that batch processing isn't available right now. Could a feature be added to enable batch processing?