shaqian / flutter_tflite

Flutter plugin for TensorFlow Lite
https://pub.dartlang.org/packages/tflite
MIT License

Custom model accuracy does not match Python benchmark and drops on deeper networks #56

Open eduardkieser opened 5 years ago

eduardkieser commented 5 years ago

Attachments: dragon_labels_33.txt, Resnet20_0_to_20.tflite.zip, data20_mini.zip

I have made a couple of TF models to recognise numbers. After converting them to TFLite, saving them and loading them back into Python, I benchmark them in Python and get classification accuracies of >99%. If I load the same TFLite model into a Flutter app and benchmark it there on the same images, I only get accuracies of around 10 to 60%.

I have noticed that deeper networks seem to be more affected by this than shallow ones. Shallower networks that benchmark at around 96% in Python get around 75% in Flutter, while the deeper ResNet-based model (attached above as Resnet20_0_to_20.tflite.zip), which benchmarks at >99% in Python, only gets around 60% on the phone. The task is a 21-class classification (0 to 20). The model has three one-hot outputs of length 11 (0 to 9 plus nan), which correspond to the hundreds, tens and ones digits.
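The output layout described above (three 11-way groups flattened into one 33-value vector, as the `reshape(3,11)` in the script below implies) can be decoded position by position. This is a minimal sketch under that assumption. `decode_digits` and the example scores are hypothetical, and it uses a per-group argmax rather than the `> 0.5` threshold the script applies.

```python
import numpy as np

# Category order copied from reverse_one_hot: index 0 is the blank ("nan")
# class, indices 1..10 are the digits 0..9.
CATEGORIES = [' ', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']


def decode_digits(output):
    """Decode a flat 33-value output into a 3-character string.

    Assumes three consecutive 11-way groups (hundreds, tens, ones), the same
    layout reverse_one_hot uses with reshape(3, 11).
    """
    groups = np.asarray(output, dtype=np.float32).reshape(3, 11)
    # Take the best class in each group independently (per-group argmax).
    return ''.join(CATEGORIES[i] for i in groups.argmax(axis=1))


# Hypothetical scores for the number 15: blank, '1', '5'.
scores = np.zeros(33, dtype=np.float32)
scores[0] = 0.9        # hundreds group -> blank
scores[11 + 2] = 0.8   # tens group -> '1'
scores[22 + 6] = 0.7   # ones group -> '5'
print(repr(decode_digits(scores)))  # ' 15'
```

Because every group always yields a class, this decoder never fails on an output, unlike a threshold that can leave a group with no active class.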

My python benchmarking code is as follows:

import numpy as np
import tensorflow as tf
from glob import glob
from random import shuffle
from PIL import Image

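def reverse_one_hot(one_hot):
    # Map the binarized output back to a 3-character string (hundreds, tens,
    # ones); list.index raises ValueError if a group has no active class.
    one_hot = one_hot.reshape(3,11)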
    categories = [' ','0','1','2','3','4','5','6','7','8','9']
    ix0 = one_hot[0].tolist().index(1)
    ix1 = one_hot[1].tolist().index(1)
    ix2 = one_hot[2].tolist().index(1)
    res = categories[ix0]+categories[ix1]+categories[ix2]
    return res

tf_lite_model_name = 'tflite_models/Resnet20_0_to_20.tflite'

file_list = glob('/data20/*/*.png')
shuffle(file_list)

interpreter = tf.compat.v2.lite.Interpreter(model_path=tf_lite_model_name)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Read the input shape the model expects.
input_shape = input_details[0]['shape']
print(f'input tensor shape: {input_shape}')

files = file_list
shuffle(files)
incorrect_classes = []
all_classes = []
n_correct = 0
n_incorrect = 0

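for ix, file in enumerate(files):

    # Load the image, resize it to 48x48 and reshape it to the input shape.
    img = Image.open(file)
    img = img.resize((48, 48))
    np_img = np.array(img, dtype=np.float32).reshape(*input_shape)
    # Per-image statistics (not used below).
    img_mean = np.mean(np_img)
    img_std = np.std(np_img)

    # Scale 8-bit pixel values to [0, 1].
    np_img = np_img/255
    interpreter.set_tensor(input_details[0]['index'], np_img)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    # Binarize: a class is active when its score exceeds 0.5.
    output_data = (output_data>0.5).astype(int)
    try:
        res = reverse_one_hot(output_data)
    except ValueError:
        # Some digit group has no active class: skip the sample, so it is
        # counted as neither correct nor incorrect.
        continue
    # The parent folder name is the label, right-aligned to 3 characters.
    label = file.split('/')[-2].rjust(3)
    if res=='   ':
        res = 'nan'

    if res == label:
        n_correct = n_correct+1
    else:
        n_incorrect = n_incorrect+1
    acc = n_correct/(n_correct+n_incorrect)
    print(f'\r ix:{ix} acc: {acc}',end='')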
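In the script above, a sample whose binarized output has no active class in some digit group makes `reverse_one_hot` raise, and the `continue` drops it from both `n_correct` and `n_incorrect`. The reported accuracy therefore only covers decodable outputs. The Flutter loop below instead stops at the first result that is not exactly three recognitions. This is a small sketch that reports both numbers, which may make the two benchmarks easier to compare. The helper name `score` and the sample data are hypothetical.

```python
def score(predictions, labels):
    """Report accuracy with and without samples that could not be decoded.

    predictions: decoded 3-character strings, or None when decoding failed
    (the case the benchmark script skips with `continue`).
    labels: folder-name labels such as '15' or 'nan'.
    """
    correct = incorrect = rejected = 0
    for pred, label in zip(predictions, labels):
        if pred is None:
            rejected += 1
            continue
        if pred == '   ':  # all-blank output means "nan", as in the script
            pred = 'nan'
        if pred == label.rjust(3):
            correct += 1
        else:
            incorrect += 1
    decoded = correct + incorrect
    total = decoded + rejected
    return {
        'acc_decoded_only': correct / decoded if decoded else 0.0,
        'acc_all_samples': correct / total if total else 0.0,
        'rejected': rejected,
    }


result = score([' 15', None, '  3', '   '], ['15', '7', '4', 'nan'])
print(result)
```

Here two of the three decodable samples are correct (2/3), but counting the rejected sample gives 2/4, so a high `rejected` count would widen the gap between the two figures.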

The Flutter benchmarking is a bit more involved, but at its core it looks as follows:

    for (String imgPath in dataMap.keys) {

      String trueLabel = dataMap[imgPath];

      var recognitions = await Tflite.runModelOnImage(
          path: imgPath, // required
          imageMean: imgMean, // defaults to 117.0
          imageStd: imgStd, // defaults to 1.0
          numResults: 3, // defaults to 5
          threshold: 0.2, // defaults to 0.1
          asynch: true // defaults to true
          );

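      // Expect exactly one recognition per digit position (hundreds, tens,
      // ones); anything else returns from the enclosing function, ending the loop.
      if (recognitions.length != 3){
        print('model returned weird length');
        return;
      }

      if (recognitions.length > 0) {
        // The three labels are parsed as ints and summed, which only yields
        // the number if each label encodes its digit's place value.
        int p0 = int.parse(recognitions[0]['label']);
        int p1 = int.parse(recognitions[1]['label']);
        int p2 = int.parse(recognitions[2]['label']);
        int intRes = p0+p1+p2;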
        String result = intRes.toString();
        // String result = recognitions[0]['label'];
        if (result == trueLabel) {
          countTrue++;
          oneIfTrue = 1;
        } else {
          countFalse++;
          oneIfTrue = 0;
        }
        bloc.addInt(oneIfTrue);
      }

      print('accuracy = ${countTrue / (countTrue + countFalse)}');
      // break;
    }
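Assume `Tflite.runModelOnImage` works as its top-N helper appears to: it ranks every entry of the single 33-value output by confidence and returns the `numResults` best above `threshold`. Under that assumption, the three recognitions are not guaranteed to be one per digit position. A strong runner-up in one group can displace a weak winner in another and still pass the `length != 3` check. Summing the labels is order-independent, so the risk lies in which groups the three results come from, not their order. The sketch below simulates that assumed ranking in Python next to a per-position pick. The label strings, scores and function names are hypothetical.

```python
import numpy as np

CATEGORIES = [' ', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# Hypothetical 33-entry label list: one entry per output, tagged by position.
LABELS = [f'{pos}:{c.strip() or "blank"}'
          for pos in ('hundreds', 'tens', 'ones') for c in CATEGORIES]


def top_n_over_all_outputs(output, num_results=3, threshold=0.2):
    """Assumed plugin-style ranking: keep every output above `threshold`,
    sort by confidence across all 33 entries, return the best `num_results`."""
    scores = np.asarray(output, dtype=np.float32).ravel()
    picked = [i for i in np.argsort(-scores, kind='stable') if scores[i] > threshold]
    return [(LABELS[i], round(float(scores[i]), 2)) for i in picked[:num_results]]


def best_per_position(output):
    """Alternative: one label per 11-way group, regardless of global rank."""
    groups = np.asarray(output, dtype=np.float32).reshape(3, 11)
    return [LABELS[g * 11 + int(i)] for g, i in enumerate(groups.argmax(axis=1))]


scores = np.zeros(33, dtype=np.float32)
scores[0] = 0.9        # hundreds: blank
scores[11 + 2] = 0.15  # tens: '1', but below the 0.2 threshold
scores[22 + 6] = 0.8   # ones: '5'
scores[22 + 7] = 0.3   # ones: runner-up '6'
print(top_n_over_all_outputs(scores))  # tens missing, two "ones" labels
print(best_per_position(scores))       # one label per position
```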

I have tried various combinations of imageMean and imageStd. I'm a bit out of my depth here; any help would be greatly appreciated. E
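For a float model, this sketch assumes the plugin normalizes each pixel channel as (pixel - imageMean) / imageStd before inference. The Python script divides by 255, which under that assumption corresponds to imageMean = 0.0 and imageStd = 255.0. The defaults noted in the Flutter snippet (117.0 and 1.0) give a very different input range. The sketch compares a few combinations. `plugin_style_normalize` is a hypothetical helper, and other preprocessing differences such as resizing or channel handling are not covered.

```python
import numpy as np


def plugin_style_normalize(pixels, image_mean, image_std):
    """Assumed float-input preprocessing: (pixel - image_mean) / image_std."""
    return (np.asarray(pixels, dtype=np.float32) - image_mean) / image_std


pixels = np.arange(256, dtype=np.float32)  # every possible 8-bit pixel value
script_input = pixels / 255                # what the Python benchmark feeds in

for mean, std in [(117.0, 1.0), (127.5, 127.5), (0.0, 255.0)]:
    x = plugin_style_normalize(pixels, mean, std)
    same = np.allclose(x, script_input)
    print(f'imageMean={mean}, imageStd={std}: '
          f'range [{x.min():.3f}, {x.max():.3f}], matches script: {same}')
```

Only the last pair reproduces the script's [0, 1] input range. Settings such as 127.5/127.5 map pixels to [-1, 1], which a model trained on [0, 1] inputs would not expect.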

eduardkieser commented 5 years ago

Hi there, has anyone had a chance to look at this yet? Is there any additional info you need from my end to help resolve this?

mdcroce commented 4 years ago

@eduardkieser did you solve this?

eduardkieser commented 4 years ago

Nope, but I also haven't looked at it in a while.