shaqian / tflite-react-native

React Native library for TensorFlow Lite
https://www.npmjs.com/package/tflite-react-native
MIT License
291 stars, 106 forks

detectObjectOnImage for model with 1 class detects a maximum of 10 occurrences #26

Open ppeelman opened 4 years ago

ppeelman commented 4 years ago

Hi,

Thank you for this library! I am using my own trained object-detection model (using "SSD Inception v1 coco" as a starting point). I have only 1 class, and the goal is to count the number of occurrences of that item in a picture.

My issue is the following: running the tflite model returns at most 10 occurrences. Even if I set 'numResultsPerClass' to a number higher than 10, I still get no more than 10 occurrences.

My code:

tflite.detectObjectOnImage(
  {
    path: imagePath,
    model: 'SSDMobileNet',
    imageMean: 127.5,
    imageStd: 127.5,
    threshold: 0.5,
    numResultsPerClass: 60
  },
  (err: Error, res: any) => {
    if (err) {
      reject(err);
    } else {
      resolve(res);
    }
  }
);

Could you please help me find what is causing this? Thanks!
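For reference, the counting step described above could look roughly like the sketch below. It is only an illustration: the `Detection` shape (with `detectedClass` and `confidenceInClass` fields) and both helper names are assumptions made for this example, not something confirmed in this thread, and the cap check is a heuristic rather than proof of truncation.

```typescript
// Assumed shape of one entry in the result array; the field names are
// an assumption for this sketch and may differ in your version.
interface Detection {
  detectedClass: string;
  confidenceInClass: number;
}

// Hypothetical helper: count detections of one class at or above a
// minimum confidence.
function countOccurrences(
  results: Detection[],
  className: string,
  minConfidence: number
): number {
  return results.filter(
    (d) => d.detectedClass === className && d.confidenceInClass >= minConfidence
  ).length;
}

// Hypothetical helper: if the number of returned detections reaches a
// suspected cap (10 in this issue), the output may have been truncated.
function isLikelyTruncated(results: Detection[], suspectedCap: number): boolean {
  return results.length >= suspectedCap;
}
```

Passing the `res` value from the callback to `countOccurrences` gives the per-image count, and `isLikelyTruncated(res, 10)` flags images where the real count might be higher than what was returned.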

thekundankamal commented 3 years ago

I am getting the same thing. How can we increase it?

ppeelman commented 3 years ago

If I recall correctly, it seemed that the number 10 was hardcoded within the Android code for the module (more specifically in TfliteReactNativeModule.java).

I placed the tflite-react-native module code locally in my project and changed that line (to 50 in my case).

I don't think I had the same issue on iOS.


  @ReactMethod
  private void detectObjectOnImage(final String path, final String model, final float mean, final float std,
                                   final float threshold, final int numResultsPerClass, final ReadableArray ANCHORS,
                                   final int blockSize, final Callback callback) throws IOException {

    ByteBuffer imgData = feedInputTensorImage(path, mean, std);

    if (model.equals("SSDMobileNet")) {
      int NUM_DETECTIONS = 50; // <== changed from the hardcoded 10
      float[][][] outputLocations = new float[1][NUM_DETECTIONS][4];
      float[][] outputClasses = new float[1][NUM_DETECTIONS];
      float[][] outputScores = new float[1][NUM_DETECTIONS];
      float[] numDetections = new float[1];

      Object[] inputArray = {imgData};
      Map<Integer, Object> outputMap = new HashMap<>();