flutter-ml / google_ml_kit_flutter

A flutter plugin that implements Google's standalone ML Kit
MIT License

Cannot convert imglib.Image to InputImage: the resulting image is not what it should be #536

Closed: andynewman10 closed this issue 9 months ago

andynewman10 commented 9 months ago

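I made an interesting experiment in which the ML Kit image labeler runs a pass-through custom TF Lite model, so the label confidences it returns are exactly the pixel values the model receives.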

This test is interesting because it lets developers verify that a generated InputImage instance is valid. In other words, it makes Image-to-InputImage conversion routines easy to study and debug.

My question is: how can I successfully create an InputImage from an image package (imglib) Image? I have been trying every bit of code I could find on the web for weeks, to no avail.

This test is Android only for now.

Steps to reproduce the behavior:

  1. Create a pass-through custom TF Lite model using the following code:
import numpy as np
import tensorflow as tf
import os

model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Flatten()
])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('testnet.tflite', 'wb') as f:
    f.write(tflite_model)

Add model metadata using pass-through parameters: mean 0, std 1, and 3072 classes (32x32x3, flattened). To add the metadata, I use metadata_writer_for_image_classifier.py, provided by the TensorFlow team.

  2. Create an ML Kit image labeler this way:
    LocalLabelerOptions customImageLabelerOptions = LocalLabelerOptions(
        confidenceThreshold: 0, modelPath: modelPath!, maxCount: 1000000);

    imageLabeler = ImageLabeler(options: customImageLabelerOptions);

I want all logits to be passed through, so maxCount can be set to 3072 (= 32x32x3, flattened) or any higher value (here 1000000). Similarly, confidenceThreshold: 0 is meant to include all values.

  3. Perform an inference using the following code:
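import 'dart:ui' show Size;

import 'package:flutter/foundation.dart' show WriteBuffer;
import 'package:google_mlkit_image_labeling/google_mlkit_image_labeling.dart';
import 'package:image/image.dart' as imglib;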

    bool readFromDisk = true;
    if (readFromDisk) {
      // Here, I am using a 32x32 JPG image, completely red
      inputImage = InputImage.fromFilePath(...);
    } else {
      // Build a red image by myself and do not use
      // InputImage.fromFilePath
      var im = imglib.Image(width: 32, height: 32);
      for (int yy = 0; yy < 32; yy++) {
        for (int xx = 0; xx < 32; xx++) {
          im.setPixelRgba(xx, yy, 255, 0, 0, 255);
        }
      }
      inputImage = convertImage(im); // see below
    }

    final Future<List<ImageLabel>> pendingTask = imageLabeler!.processImage(inputImage);

    pendingTask.then((List<ImageLabel> imageLabels) {
       // Look at imageLabels with the debugger HERE (see Step 4 below)
    });

  InputImage convertImage(imglib.Image image) {
    final WriteBuffer allBytes = WriteBuffer();

    // Slow code, just for testing purposes
    for (int y = 0; y < image.height; y++) {
      for (int x = 0; x < image.width; x++) {
        final p = image.getPixel(x, y);
        int r = p.r.toInt();
        int g = p.g.toInt();
        int b = p.b.toInt();
        allBytes.putUint8(b);
        allBytes.putUint8(g);
        allBytes.putUint8(r);
        allBytes.putUint8(0xFF);
      }
    }
    final bytes = allBytes.done().buffer.asUint8List();

    final metadata = InputImageMetadata(
        // Using bgra8888 here on Android is expected to be supported, isn't it?
        // (CameraControllers have a limitation whereby yuv420 is the only supported value on Android,
        // note this is a different story since we are not using the Camera package here...)
        format: InputImageFormat.bgra8888,
        size: Size(image.width.toDouble(), image.height.toDouble()),
        rotation: InputImageRotation.rotation0deg,
        bytesPerRow: image.width * 4); // each pixel has 4 bytes (B, G, R, A)

    return InputImage.fromBytes(bytes: bytes, metadata: metadata);
  }
  4. Place a breakpoint in the then handler in the code above and inspect the values of imageLabels.
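convertImage writes each pixel as four bytes in B, G, R, A order, so one row occupies width * 4 bytes. The Python sketch below (the helper names are hypothetical) reproduces that byte layout for a plain RGB grid, which makes it easy to check the buffer length, the stride and the offset of a given pixel before the bytes are handed to InputImage.fromBytes. Per the later comments in this thread, the Android side of the plugin only accepts nv21, so this layout is only useful where bgra8888 is actually accepted.

```python
def pack_bgra8888(image):
    """Pack a height x width grid of (r, g, b) tuples into BGRA8888 bytes."""
    buf = bytearray()
    for row in image:
        for r, g, b in row:
            # same per-pixel order as convertImage: B, G, R, then opaque alpha
            buf += bytes((b, g, r, 0xFF))
    return bytes(buf)


def bgra_offset(x, y, bytes_per_row):
    """Byte offset of pixel (x, y) when every pixel takes 4 bytes."""
    return y * bytes_per_row + x * 4


width, height = 32, 32
red = [[(255, 0, 0)] * width for _ in range(height)]
data = pack_bgra8888(red)
bytes_per_row = width * 4  # 128 for a 32-pixel-wide image
off = bgra_offset(5, 7, bytes_per_row)
print(len(data))                 # 4096
print(tuple(data[off:off + 4]))  # (0, 0, 255, 255)
```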

Expected behavior

Whether readFromDisk is true or false, I should get the same results. More specifically, since the image is pure red, I should get logits (label confidence values) of [255.0] only.
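To see why a solid red image should produce no value other than 255.0 (red channel) and 0.0 (green and blue channels), here is a minimal Python sketch of what the pass-through model computes. It assumes the metadata normalization (value - mean) / std is applied to 0-255 pixel values with mean 0 and std 1, and that the input tensor is laid out in height-width-channel RGB order. The helper names are hypothetical.

```python
def make_solid_image(width, height, rgb):
    """Return a height x width grid of (r, g, b) tuples, all equal to rgb."""
    return [[rgb] * width for _ in range(height)]


def passthrough_outputs(image, mean=0.0, std=1.0):
    """Normalize each channel as (value - mean) / std, then flatten in HWC order."""
    out = []
    for row in image:
        for pixel in row:
            for value in pixel:  # R, G, B
                out.append((value - mean) / std)
    return out


logits = passthrough_outputs(make_solid_image(32, 32, (255, 0, 0)))
print(len(logits))          # 3072
print(logits[:6])           # [255.0, 0.0, 0.0, 255.0, 0.0, 0.0]
print(sorted(set(logits)))  # [0.0, 255.0]
```

With mean 0 and std 1 the normalization is the identity, so any other value reaching the labeler (such as 239.0 or 198.0) means the pixels ML Kit decoded differ from the intended image.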

Actual behavior

When readFromDisk is true I get the expected results.

When readFromDisk is false, I get:

Additional testing

I rewrote the convertImage function so that an InputImage with yuv420 encoding is used: the results are also wrong.

In this case, the logits (label confidence values) are [239.0, 198.0, 61.0, 15.0, 9.0, 0.0]. Again, they should be [255.0] only, as in the InputImage.fromFilePath case, which shows that reading an image from disk works fine.

  Uint8List colorconvertRGB_IYUV_I420(imglib.Image image) {
    final int width = image.width;
    final int height = image.height;
    final int frameSize = width * height;
    final int chromaSize = frameSize ~/ 4;

    // I420 layout: Y plane, then U plane, then V plane
    int yIndex = 0;
    int uIndex = frameSize;
    int vIndex = frameSize + chromaSize;
    final Uint8List yuv = Uint8List(frameSize + 2 * chromaSize);

    for (int j = 0; j < height; j++) {
      for (int i = 0; i < width; i++) {
        // Read channels through the imglib API: image.data!.toUint8List()
        // holds raw channel bytes, not one packed ARGB int per pixel.
        final p = image.getPixel(i, j);
        final int R = p.r.toInt();
        final int G = p.g.toInt();
        final int B = p.b.toInt();

        final int Y = ((66 * R + 129 * G + 25 * B + 128) >> 8) + 16;
        final int U = ((-38 * R - 74 * G + 112 * B + 128) >> 8) + 128;
        final int V = ((112 * R - 94 * G - 18 * B + 128) >> 8) + 128;

        yuv[yIndex++] = (Y < 0) ? 0 : ((Y > 255) ? 255 : Y);

        // one U and one V sample per 2x2 block (every other pixel and scanline)
        if (j % 2 == 0 && i % 2 == 0) {
          yuv[uIndex++] = (U < 0) ? 0 : ((U > 255) ? 255 : U);
          yuv[vIndex++] = (V < 0) ? 0 : ((V > 255) ? 255 : V);
        }
      }
    }
    return yuv;
  }

  InputImage imgLibImageToInputImage(imglib.Image image) {
    final bytes = colorconvertRGB_IYUV_I420(image);

    final metadata = InputImageMetadata(
        format: InputImageFormat.yuv420,
        size: Size(image.width.toDouble(), image.height.toDouble()),
        rotation: InputImageRotation.rotation0deg,
        bytesPerRow: image.width); // stride of the Y plane

    return InputImage.fromBytes(bytes: bytes, metadata: metadata);
  }
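For reference, I420 (also called IYUV) stores the full-resolution Y plane first, then the quarter-resolution U plane, then the V plane. A minimal Python sketch, assuming even width and height and the same integer BT.601 coefficients as the Dart code above; encode_i420 and rgb_to_yuv601 are hypothetical helper names.

```python
def _clamp8(c):
    """Clamp an int to the 0..255 byte range."""
    return max(0, min(255, c))


def rgb_to_yuv601(r, g, b):
    """Integer BT.601 (studio range) approximation, same coefficients as the Dart code."""
    y = ((66 * r + 129 * g + 25 * b + 128) >> 8) + 16
    u = ((-38 * r - 74 * g + 112 * b + 128) >> 8) + 128
    v = ((112 * r - 94 * g - 18 * b + 128) >> 8) + 128
    return _clamp8(y), _clamp8(u), _clamp8(v)


def encode_i420(image):
    """Pack a height x width grid of (r, g, b) into planar I420: Y, then U, then V."""
    height, width = len(image), len(image[0])
    if width % 2 or height % 2:
        raise ValueError("4:2:0 subsampling needs even width and height")
    cw = width // 2  # chroma plane width
    y_plane = bytearray(width * height)
    u_plane = bytearray(cw * (height // 2))
    v_plane = bytearray(cw * (height // 2))
    for j in range(height):
        for i in range(width):
            y, u, v = rgb_to_yuv601(*image[j][i])
            y_plane[j * width + i] = y
            if j % 2 == 0 and i % 2 == 0:
                # one U and one V sample per 2x2 block, from its top-left pixel
                k = (j // 2) * cw + i // 2
                u_plane[k] = u
                v_plane[k] = v
    return bytes(y_plane + u_plane + v_plane)


red = [[(255, 0, 0)] * 32 for _ in range(32)]
buf = encode_i420(red)
print(len(buf))                      # 1536
print(buf[0], buf[1024], buf[1280])  # 82 90 240
```

A correct I420 buffer for a pure red image is therefore 1024 bytes of 82, then 256 bytes of 90 (U), then 256 bytes of 240 (V).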

Platform:

google_mlkit_image_labeling: ^0.9.0
image: ^4.0.17
andynewman10 commented 9 months ago

I just pushed the .tflite model to use for the repro (that way, it is not necessary to run the Python script to create the model):

https://github.com/andynewman10/testrepo/blob/main/testnet.tflite

andynewman10 commented 9 months ago

I did manage to read the Android code that performs the inference, but I had to decompile the AAR; for some reason, the code is sometimes very hard to find on GitHub.

Anyway, as @fbernaly mentioned (thank you!), the ML Kit code dealing with InputImage instances can be found here (for Android):

https://github.com/flutter-ml/google_ml_kit_flutter/blob/master/packages/google_mlkit_image_labeling/android/src/main/java/com/google_mlkit_image_labeling/ImageLabelDetector.java

Right away I saw that the only format the Flutter package supports on Android is nv21, whereas the native Android version also supports yv12 and yuv_420_888. That's important to know!
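The packed 4:2:0 layouts discussed here differ only in where the chroma samples go after the Y plane. The Python sketch below returns the byte offsets of the U and V samples for chroma block k. It assumes even dimensions and tightly packed planes with no row padding (Android pads YV12 rows to 16-byte strides, which makes no difference at 32x32). chroma_offsets is a hypothetical helper. yuv_420_888 is left out because it is a flexible format described by per-plane row and pixel strides rather than a single byte layout.

```python
def chroma_offsets(fmt, width, height, k):
    """Return (u_offset, v_offset) of the k-th chroma sample in a packed 4:2:0 buffer.

    Assumes even width/height and tightly packed planes (no row padding).
    """
    frame = width * height   # size of the Y plane
    quarter = frame // 4     # size of one subsampled chroma plane
    if fmt == "nv21":        # Y plane, then interleaved V,U pairs
        return frame + 2 * k + 1, frame + 2 * k
    if fmt == "yv12":        # Y plane, then V plane, then U plane
        return frame + quarter + k, frame + k
    if fmt == "i420":        # Y plane, then U plane, then V plane
        return frame + k, frame + quarter + k
    raise ValueError(f"unsupported layout: {fmt}")


for fmt in ("nv21", "yv12", "i420"):
    print(fmt, chroma_offsets(fmt, 32, 32, 0))
# nv21 (1025, 1024)
# yv12 (1280, 1024)
# i420 (1024, 1280)
```

All three buffers have the same length (width * height * 3 / 2), which is why sending bytes in the wrong layout produces no error, only wrong colors.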

The code that I pasted above (colorconvertRGB_IYUV_I420) produces yuv420 data and therefore has no chance of working.

Looking at the code I also discovered that the minimal image size is 32x32 (which is the size I am using, phew...)
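Given these constraints (nv21 only, and a minimum size of 32x32 as observed in the plugin code), a cheap pre-flight check can catch a malformed buffer before it reaches ML Kit. A hypothetical Python helper, assuming both sides must be at least 32 pixels, 4:2:0 subsampling requires even dimensions, and the buffer is tightly packed at width * height * 3 / 2 bytes:

```python
MIN_SIDE = 32  # minimum size reported in this thread; treated as an assumption


def check_nv21_buffer(width, height, buf_len):
    """Return a list of problems with an NV21 buffer description (empty if OK)."""
    problems = []
    if width < MIN_SIDE or height < MIN_SIDE:
        problems.append(f"image is {width}x{height}, smaller than {MIN_SIDE}x{MIN_SIDE}")
    if width % 2 or height % 2:
        problems.append("4:2:0 subsampling expects even width and height")
    expected = width * height * 3 // 2
    if buf_len != expected:
        problems.append(f"buffer has {buf_len} bytes, expected {expected}")
    return problems


print(check_nv21_buffer(32, 32, 1536))  # []
print(check_nv21_buffer(16, 16, 1024))  # too small and wrong length
```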

So I went ahead with an RGB to NV21 converter, using the following code:

  static Uint8List encodeYUV420SP(imglib.Image image) {
    final int width = image.width;
    final int height = image.height;
    final int frameSize = width * height;
    // Y plane (frameSize bytes) followed by frameSize / 2 bytes of interleaved VU
    final Uint8List yuv420sp = Uint8List((frameSize * 3) ~/ 2);

    int yIndex = 0;
    int uvIndex = frameSize;

    for (int j = 0; j < height; j++) {
      for (int i = 0; i < width; i++) {
        // Read channels through the imglib API: image.data!.toUint8List()
        // holds raw channel bytes, not one packed ARGB int per pixel.
        final p = image.getPixel(i, j);
        final int R = p.r.toInt();
        final int G = p.g.toInt();
        final int B = p.b.toInt();

        // well-known integer RGB to YUV (BT.601) conversion
        final int Y = ((66 * R + 129 * G + 25 * B + 128) >> 8) + 16;
        final int U = ((-38 * R - 74 * G + 112 * B + 128) >> 8) + 128;
        final int V = ((112 * R - 94 * G - 18 * B + 128) >> 8) + 128;

        /* NV21 has a plane of Y followed by a plane of interleaved VU, each chroma
        component subsampled by a factor of 2: for every 4 Y pixels there is 1 V and 1 U.
        Note the sampling is every other pixel AND every other scanline. */
        yuv420sp[yIndex++] = (Y < 0) ? 0 : ((Y > 255) ? 255 : Y);
        if (j % 2 == 0 && i % 2 == 0) {
          yuv420sp[uvIndex++] = (V < 0) ? 0 : ((V > 255) ? 255 : V);
          yuv420sp[uvIndex++] = (U < 0) ? 0 : ((U > 255) ? 255 : U);
        }
      }
    }

    return yuv420sp;
  }

  InputImage imgLibImageToInputImage(imglib.Image image) {
    final bytes = encodeYUV420SP(image);

    final metadata = InputImageMetadata(
        format: InputImageFormat.nv21,
        size: Size(image.width.toDouble(), image.height.toDouble()),
        rotation: InputImageRotation.rotation0deg,
        bytesPerRow: 0); // ignored

    return InputImage.fromBytes(bytes: bytes, metadata: metadata);
  }

This is some code I found on the web, and it looks good to me: it respects the NV21 encoding, where the Y plane comes first, followed by V and U stored in interleaved form.

And it still doesn't work: InputImage.fromFilePath does work, and the same image, generated with my code, doesn't.
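In the spirit of the pass-through test, one way to debug this kind of mismatch is to decode the generated NV21 buffer back to RGB and compare it with the source image. A minimal Python sketch of the inverse integer BT.601 conversion, assuming a tightly packed buffer; decode_nv21 and max_channel_error are hypothetical helpers. The round trip is only approximate because of integer rounding and chroma subsampling. The test buffer uses Y = 82, V = 240, U = 90, which is what the forward formula in this thread gives for pure red.

```python
def _clamp8(c):
    """Clamp an int to the 0..255 byte range."""
    return max(0, min(255, c))


def decode_nv21(buf, width, height):
    """Decode a tightly packed NV21 buffer into a height x width grid of (r, g, b)."""
    frame = width * height
    if len(buf) != frame * 3 // 2:
        raise ValueError("unexpected NV21 buffer length")
    image = []
    for j in range(height):
        row = []
        for i in range(width):
            y = buf[j * width + i]
            # each chroma row holds width // 2 (V, U) pairs, i.e. width bytes
            k = frame + (j // 2) * width + (i // 2) * 2
            v, u = buf[k], buf[k + 1]
            c, d, e = y - 16, u - 128, v - 128
            r = _clamp8((298 * c + 409 * e + 128) >> 8)
            g = _clamp8((298 * c - 100 * d - 208 * e + 128) >> 8)
            b = _clamp8((298 * c + 516 * d + 128) >> 8)
            row.append((r, g, b))
        image.append(row)
    return image


def max_channel_error(a, b):
    """Largest per-channel absolute difference between two equally sized grids."""
    return max(abs(p - q)
               for row_a, row_b in zip(a, b)
               for pa, pb in zip(row_a, row_b)
               for p, q in zip(pa, pb))


# NV21 for a solid red 32x32 image: Y plane of 82, then (V, U) = (240, 90) pairs
red_nv21 = bytes([82] * 1024 + [240, 90] * 256)
decoded = decode_nv21(red_nv21, 32, 32)
red = [[(255, 0, 0)] * 32 for _ in range(32)]
print(decoded[0][0])                    # (255, 1, 0)
print(max_channel_error(decoded, red))  # 1
```

If a buffer produced by the Dart encoder decodes to something far from the source image (for example mostly blue or black instead of red), the problem lies in how pixels are read or packed, not in ML Kit.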

andynewman10 commented 9 months ago

Following up on my previous message: things are now working as expected, so I am closing this issue. Things indeed work when nv21 is used to generate the InputImage. I mistakenly believed I had to use yuv420 because of a limitation in Flutter, but that limitation was on the camera package side (camera could not handle nv21 until recently), not on the ML Kit side.
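Since nv21 turns out to be the format to use on Android, here is the same NV21 packing as a minimal Python reference, handy for producing expected bytes to compare against the output of a Dart encoder. It assumes even dimensions and the integer BT.601 coefficients used in the Dart code; encode_nv21 and rgb_to_yuv601 are hypothetical helper names, not part of any ML Kit API.

```python
def _clamp8(c):
    """Clamp an int to the 0..255 byte range."""
    return max(0, min(255, c))


def rgb_to_yuv601(r, g, b):
    """Integer BT.601 (studio range) approximation, same coefficients as the Dart code."""
    y = ((66 * r + 129 * g + 25 * b + 128) >> 8) + 16
    u = ((-38 * r - 74 * g + 112 * b + 128) >> 8) + 128
    v = ((112 * r - 94 * g - 18 * b + 128) >> 8) + 128
    return _clamp8(y), _clamp8(u), _clamp8(v)


def encode_nv21(image):
    """Pack a height x width grid of (r, g, b) into NV21: Y plane, then V,U pairs."""
    height, width = len(image), len(image[0])
    if width % 2 or height % 2:
        raise ValueError("NV21 (4:2:0) needs even width and height")
    frame = width * height
    out = bytearray(frame * 3 // 2)
    uv = frame  # interleaved chroma starts right after the Y plane
    for j in range(height):
        for i in range(width):
            y, u, v = rgb_to_yuv601(*image[j][i])
            out[j * width + i] = y
            if j % 2 == 0 and i % 2 == 0:
                out[uv] = v      # NV21 stores V first...
                out[uv + 1] = u  # ...then U
                uv += 2
    return bytes(out)


img = [[(255, 0, 0)] * 32 for _ in range(32)]
img[0][0] = (0, 0, 255)  # one blue pixel, the top-left sample of the first 2x2 block
nv21 = encode_nv21(img)
print(len(nv21))                                  # 1536
print(nv21[0], nv21[1], nv21[1024], nv21[1025])   # 41 82 110 240
```

The blue pixel shows the ordering: its V (110) lands at offset 1024 and its U (240) at 1025, while the next block's pair is the red (240, 90).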

fbernaly commented 9 months ago

Great. Actually, that is specified in the README: you need to use nv21 when using the camera plugin.

https://github.com/flutter-ml/google_ml_kit_flutter/tree/master/packages/google_mlkit_commons#creating-an-inputimage