am15h / tflite_flutter_helper

TensorFlow Lite Flutter Helper Library
https://pub.dev/packages/tflite_flutter_helper
Apache License 2.0

High latency of Image Processor due to NormalizeOp #22

Open leeflix opened 3 years ago

leeflix commented 3 years ago

Any ideas why the latency of applying an image processor to a tensor image varies so drastically?

runs: 100
avg: 532ms
min: 24ms
max: 7566ms

[Plot: processing time per run]

Code to reproduce:

import 'dart:typed_data';

import 'package:flutter/services.dart';
import 'package:flutter/widgets.dart' hide Image;
import 'package:image/image.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

Future<void> main() async {
  WidgetsFlutterBinding.ensureInitialized();

  // decodeImage returns null if the bytes cannot be decoded.
  final Image image = decodeImage(
    (await rootBundle.load("assets/images/test.JPG")).buffer.asUint8List(),
  )!;

  final ImageProcessor imageProcessor = ImageProcessorBuilder()
      .add(ResizeOp(300, 300, ResizeMethod.BILINEAR))
      .add(NormalizeOp(127.5, 127.5))
      .build();

  int? min;
  int? max;
  int total = 0;
  const int runs = 100;
  final Stopwatch stopwatch = Stopwatch();

  for (int i = 0; i < runs; i++) {
    final TensorImage tensorImage = TensorImage.fromImage(image);
    // Time only the processing step, not the TensorImage construction.
    stopwatch
      ..reset()
      ..start();
    // `input` is what would be fed to the interpreter.
    final ByteBuffer input = imageProcessor.process(tensorImage).buffer;
    stopwatch.stop();
    final int dt = stopwatch.elapsedMilliseconds;
    total += dt;
    if (min == null || dt < min) min = dt;
    if (max == null || dt > max) max = dt;
  }

  print("runs: $runs");
  print("avg: ${total ~/ runs}ms");
  print("min: ${min}ms");
  print("max: ${max}ms");
}
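The timing loop above can be factored into a small reusable harness. The numbers above (avg 532ms, but min 24ms) show that a few outliers dominate the average, so reporting the median as well gives a better picture of a typical run. This is a minimal sketch; benchmark, summarize and LatencyStats are hypothetical helper names, not part of tflite_flutter_helper, and the untimed warm-up calls are an added design choice meant to keep first-call costs out of the samples.

```dart
/// Summary of per-run latencies in milliseconds.
class LatencyStats {
  final int runs;
  final double avg;
  final int min;
  final int median;
  final int max;

  const LatencyStats(this.runs, this.avg, this.min, this.median, this.max);

  @override
  String toString() => 'runs: $runs, avg: ${avg.toStringAsFixed(1)}ms, '
      'min: ${min}ms, median: ${median}ms, max: ${max}ms';
}

/// Computes summary statistics; throws [ArgumentError] on an empty list.
LatencyStats summarize(List<int> samplesMs) {
  if (samplesMs.isEmpty) {
    throw ArgumentError.value(samplesMs, 'samplesMs', 'must not be empty');
  }
  final sorted = List<int>.of(samplesMs)..sort();
  final total = sorted.fold<int>(0, (sum, t) => sum + t);
  return LatencyStats(
    sorted.length,
    total / sorted.length,
    sorted.first,
    sorted[sorted.length ~/ 2], // upper median for even-length lists
    sorted.last,
  );
}

/// Calls [body] [warmup] times untimed, then [runs] times timed with a
/// [Stopwatch], returning each timed run's duration in milliseconds.
List<int> benchmark(void Function() body, {int runs = 100, int warmup = 5}) {
  for (var i = 0; i < warmup; i++) {
    body();
  }
  final stopwatch = Stopwatch();
  final samples = <int>[];
  for (var i = 0; i < runs; i++) {
    stopwatch
      ..reset()
      ..start();
    body();
    stopwatch.stop();
    samples.add(stopwatch.elapsedMilliseconds);
  }
  return samples;
}
```

With the objects from the reproduction code, a call could look like print(summarize(benchmark(() => imageProcessor.process(TensorImage.fromImage(image))))).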
leeflix commented 3 years ago

I found out that the cause of this behaviour is the normalization step in the image processor. If I remove .add(NormalizeOp(127.5, 127.5)), I get the following results:

runs: 100
avg: 51ms
min: 30ms
max: 148ms

[Plot: comparison of processing time per run, with and without NormalizeOp]

Does anyone have an explanation for this behaviour?

am15h commented 3 years ago

Thanks for the analysis @FelixBruebach.

NormalizeOp is an expensive operation. I don't think it can be optimized any further, so if you can use a quantized model instead of a float one, you should be able to achieve the lowest possible latency.

I can't comment precisely on the shape of the graphs you shared, but in the benchmarks I ran earlier (on object_detection_flutter), a float model (using NormalizeOp) took at most 2x as long as its quantized counterpart when running the application.
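To make the float-versus-quantized difference concrete: a quantized (uint8) model can take the RGB bytes as they are, while a float model needs a subtraction and a division for every channel value, plus an input buffer four times as large. A minimal sketch, assuming the pixels are already available as a packed RGB Uint8List (obtaining them is not shown); floatInputFromRgb and quantizedInputFromRgb are hypothetical names, not tflite_flutter_helper API.

```dart
import 'dart:typed_data';

/// Float model input: every channel value becomes (value - mean) / std,
/// stored as a 4-byte float. With mean = std = 127.5 this maps 0..255
/// onto -1..1.
Float32List floatInputFromRgb(Uint8List rgb, double mean, double std) {
  final out = Float32List(rgb.length);
  for (var i = 0; i < rgb.length; i++) {
    out[i] = (rgb[i] - mean) / std;
  }
  return out;
}

/// Quantized (uint8) model input: the bytes are used unchanged, so there
/// is no normalization pass at all.
Uint8List quantizedInputFromRgb(Uint8List rgb) => rgb;
```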

leeflix commented 3 years ago

The graphs plot the per-image processing times (dt) calculated in the loop (see the code).

I made a fork of object_detection_flutter and replaced the model with a trained model (ssd_mobiledet_cpu_coco) from the Model Zoo 1 of the TensorFlow Object Detection API and added the NormalizeOp.

You can pull the fork to observe the behaviour yourself, or watch the video I added to the fork in which I recorded the issue. In the video you can see that the pre-processing time sometimes spikes above 2 seconds, a factor of ~100.

leeflix commented 3 years ago

Could you share the project you conducted the benchmark on? Maybe I could find my error that way. @am15h

leeflix commented 3 years ago

I finally got very good latency by replacing the ImageProcessor containing the NormalizeOp with the following code to normalize the input:

import 'dart:typed_data';

import 'package:image/image.dart';

/// Converts an already-resized image into the bytes of a Float32
/// [1, height, width, 3] tensor, normalizing each channel as
/// (value - mean) / std.
Uint8List imageToByteListFloat32(Image image, double mean, double std) {
  final convertedBytes = Float32List(1 * image.height * image.width * 3);
  int pixelIndex = 0;
  for (var i = 0; i < image.height; i++) {
    for (var j = 0; j < image.width; j++) {
      final pixel = image.getPixel(j, i);
      convertedBytes[pixelIndex++] = (getRed(pixel) - mean) / std;
      convertedBytes[pixelIndex++] = (getGreen(pixel) - mean) / std;
      convertedBytes[pixelIndex++] = (getBlue(pixel) - mean) / std;
    }
  }
  // Reinterpret the float data as raw bytes for the interpreter input.
  return convertedBytes.buffer.asUint8List();
}

I found this code in the example of the other TFLite plugin for Flutter, "tflite" (I changed the code a little, though).
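Because every channel value is an 8-bit integer, (value - mean) / std can only take 256 distinct results, so one possible refinement of the function above is to precompute them in a lookup table and replace the per-value arithmetic with an indexed read. A sketch, assuming packed RGB bytes as input; buildNormalizationTable and normalizeRgbWithTable are hypothetical names, and whether this is measurably faster than the direct version on a given device has not been tested here.

```dart
import 'dart:typed_data';

/// Precomputes (v - mean) / std for every possible 8-bit value v.
Float32List buildNormalizationTable(double mean, double std) {
  final table = Float32List(256);
  for (var v = 0; v < 256; v++) {
    table[v] = (v - mean) / std;
  }
  return table;
}

/// Normalizes packed RGB bytes into a Float32 tensor using table lookups
/// and returns the tensor's raw bytes, as imageToByteListFloat32 does.
Uint8List normalizeRgbWithTable(Uint8List rgb, Float32List table) {
  final out = Float32List(rgb.length);
  for (var i = 0; i < rgb.length; i++) {
    out[i] = table[rgb[i]];
  }
  return out.buffer.asUint8List();
}
```

The table only needs to be built once per (mean, std) pair and can be reused for every frame.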

Feel free to close this issue, but I still think there is something odd about the implementation of NormalizeOp, although sadly I have no idea what the cause could be.

am15h commented 3 years ago

Thanks @FelixBruebach. Glad that you are able to achieve better results. I will need to investigate the differences and check if this procedure can be made compatible with the current image processor architecture.

xunkai55 commented 3 years ago

We've also seen people complain about the performance of NormalizeOp in TFLite Support. Oddly, though, I remember that I couldn't reproduce the poor performance.

Thanks for the information. I will take a deeper look.

AntoineChauviere commented 1 year ago

@leeflix could you please share the full code in which you replace NormalizeOp with this function? I don't understand at which point you resize and normalize the image.

leeflix commented 1 year ago

@AntoineChauviere This is a class I wrote in order to use EfficientDet Lite:



import 'dart:io';
import 'dart:typed_data';

import 'package:flutter/services.dart';
import 'package:image/image.dart';
import 'package:quiver/iterables.dart';
import 'package:tflite_flutter/tflite_flutter.dart';

import '../object_detection/data/bounding_box.dart';
import '../object_detection/data/detection.dart';

class EfficientDetLite {
  final Interpreter _interpreter;
  final List<String> _labelMap;
  final int _inputWidth;
  final int _inputHeight;
  final Interpolation _interpolation = Interpolation.linear;

  EfficientDetLite(
    this._interpreter,
    this._labelMap,
  )   : _inputWidth = _interpreter.getInputTensors()[0].shape[2],
        _inputHeight = _interpreter.getInputTensors()[0].shape[1];

  static Future<EfficientDetLite> fromAssets({
    required String interpreterAssetName,
    required String labelMapAssetName,
    int? threads,
  }) async {
    if (threads == null) {
      if (Platform.isAndroid) {
        threads = 4;
      } else if (Platform.isIOS) {
        threads = 2;
      } else {
        threads = 1;
      }
    }

    Interpreter interpreter = await Interpreter.fromAsset(
      interpreterAssetName,
      options: InterpreterOptions()..threads = threads,
    );

    // Split on both "\n" and "\r\n" line endings.
    List<String> labelMap =
        (await rootBundle.loadString('assets/$labelMapAssetName'))
            .split(RegExp(r'\r?\n'));

    return EfficientDetLite(interpreter, labelMap);
  }

  List<Detection> detect(Image inputImage) {
    if (inputImage.width != inputImage.height) {
      throw ArgumentError('The input image must be square.');
    }

    Image image = copyResize(
      inputImage,
      width: _inputWidth,
      height: _inputHeight,
      interpolation: _interpolation,
    );

    var input = _imageToUint8List(image);
    var output0 = List.filled(1 * 25, 0).reshape([1, 25]);
    var output1 = List.filled(1 * 25 * 4, 0).reshape([1, 25, 4]);
    var output2 = List.filled(1, 0).reshape([1]);
    var output3 = List.filled(1 * 25, 0).reshape([1, 25]);

    _interpreter.runForMultipleInputs(
      [input],
      {0: output0, 1: output1, 2: output2, 3: output3},
    );

    List<Detection> detections =
        zip<dynamic>([output1[0], output0[0], output3[0]])
            .map(
              (x) => Detection(
                _labelMap[x[2].toInt()],
                x[1],
                BoundingBox(
                  xmin: (x[0][1] * inputImage.width).round(),
                  ymin: (x[0][0] * inputImage.height).round(),
                  xmax: (x[0][3] * inputImage.width).round(),
                  ymax: (x[0][2] * inputImage.height).round(),
                ),
              ),
            )
            .toList();

    return detections;
  }

  /// Packs the RGB channels of [image] into a flat [1, height, width, 3]
  /// uint8 buffer. The raw channel values are used without normalization.
  static Uint8List _imageToUint8List(Image image) {
    final bytes = Uint8List(1 * image.height * image.width * 3);
    int i = 0;
    for (var y = 0; y < image.height; y++) {
      for (var x = 0; x < image.width; x++) {
        final int pixel = image.getPixel(x, y);
        bytes[i++] = getRed(pixel);
        bytes[i++] = getGreen(pixel);
        bytes[i++] = getBlue(pixel);
      }
    }
    return bytes;
  }
}
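The class above rejects non-square images and reads each box in the normalized [ymin, xmin, ymax, xmax] order that its indexing implies. If you center-crop a larger image to a square before calling detect, the resulting boxes are relative to the crop and need the crop offset added back. A minimal sketch of that bookkeeping, assuming a center crop; PixelRect, centerSquare and boxToImagePixels are hypothetical names, and the cropping itself (for example with the image package) is not shown.

```dart
/// Axis-aligned rectangle in pixel coordinates.
class PixelRect {
  final int left, top, width, height;
  const PixelRect(this.left, this.top, this.width, this.height);
}

/// Largest centered square that fits inside a width x height image.
PixelRect centerSquare(int width, int height) {
  final side = width < height ? width : height;
  return PixelRect((width - side) ~/ 2, (height - side) ~/ 2, side, side);
}

/// Maps a normalized [ymin, xmin, ymax, xmax] box predicted on [crop]
/// back to original-image pixels, returned as [xmin, ymin, xmax, ymax].
List<int> boxToImagePixels(List<double> box, PixelRect crop) => [
      crop.left + (box[1] * crop.width).round(),
      crop.top + (box[0] * crop.height).round(),
      crop.left + (box[3] * crop.width).round(),
      crop.top + (box[2] * crop.height).round(),
    ];
```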