google-ai-edge / mediapipe

Cross-platform, customizable ML solutions for live and streaming media.
https://ai.google.dev/edge/mediapipe
Apache License 2.0

Mobile SSD models are expected to have exactly 4 outputs, found 2 #5550

Closed libofei2004 closed 1 month ago

libofei2004 commented 3 months ago

Have I written custom code (as opposed to using a stock example script provided in MediaPipe)

None

OS Platform and Distribution

Ubuntu 22 in WSL2, Android 12

Python Version

3.10

MediaPipe Model Maker version

2.0.4.1

Task name (e.g. Image classification, Gesture recognition etc.)

object detector

Describe the actual behavior

I used mediapipe_model_maker 2.0.4.1 to train a model and loaded it in an Android program, but the program fails to run and throws an exception.

Describe the expected behaviour

Expected: the model loads and runs in the Android program. Instead, the program throws: java.lang.IllegalArgumentException: Error occurred when initializing ObjectDetector: Mobile SSD models are expected to have exactly 4 outputs, found 2
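One way to check what the error is counting is to list the output tensors of the exported file. The sketch below is not from the original report: `describe_outputs` and `read_output_shapes` are hypothetical helper names, and the second helper assumes TensorFlow is installed so that `tf.lite.Interpreter` is available. The expected count of 4 is taken from the error message itself.

```python
from typing import List, Sequence

# Output count the ObjectDetector asks for, according to the error message.
EXPECTED_OUTPUTS = 4


def describe_outputs(output_shapes: Sequence[Sequence[int]]) -> str:
    """Say whether a model's output count matches the 4 named in the error."""
    found = len(output_shapes)
    if found == EXPECTED_OUTPUTS:
        return "4 outputs: matches the count the error message asks for"
    return f"{found} outputs: the error message asks for exactly {EXPECTED_OUTPUTS}"


def read_output_shapes(model_path: str) -> List[List[int]]:
    """Read output tensor shapes from a .tflite file (assumes TensorFlow is installed)."""
    import tensorflow as tf  # imported lazily so describe_outputs works without it

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return [detail["shape"].tolist() for detail in interpreter.get_output_details()]
```

Calling `describe_outputs(read_output_shapes('dogs.tflite'))` on the exported file would show whether it really has 2 outputs, as the exception reports.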

Standalone code/steps you may have used to try to get what you need

1. I trained a TFLite model with mediapipe_model_maker.
The code is:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

from mediapipe_model_maker import object_detector

train_dataset_path = '/mnt/d/workspace/imgupload/img/selected1/bt3/train/shot'
validation_dataset_path = '/mnt/d/workspace/imgupload/img/selected1/bt3/train/shot'
cache_dir = '/mnt/d/workspace/imgupload/img/selected1/tmp'

train_data = object_detector.Dataset.from_pascal_voc_folder(
    train_dataset_path,
    cache_dir=cache_dir)

validate_data = object_detector.Dataset.from_pascal_voc_folder(
    validation_dataset_path,
    cache_dir=cache_dir)

hparams = object_detector.HParams(batch_size=8, learning_rate=0.3, epochs=50, export_dir='exported_model')
options = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_V2,
    hparams=hparams)

model = object_detector.ObjectDetector.create(
    train_data=train_data,
    validation_data=validate_data,
    options=options)

loss, coco_metrics = model.evaluate(validate_data, batch_size=4)
print(f"Validation loss: {loss}")
print(f"Validation coco metrics: {coco_metrics}")
model.export_model('dogs.tflite')
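Note that the script above points `train_dataset_path` and `validation_dataset_path` at the same folder, so the reported validation loss and COCO metrics are measured on the training images. This is unrelated to the exception, but a held-out split gives more meaningful numbers. A minimal sketch of a deterministic split follows; `split_names` is a hypothetical helper that only partitions a list of file names, and copying the files into separate folders is left out.

```python
import random
from typing import List, Sequence, Tuple


def split_names(names: Sequence[str], val_fraction: float = 0.2,
                seed: int = 0) -> Tuple[List[str], List[str]]:
    """Deterministically split names into (train, validation) lists."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be strictly between 0 and 1")
    shuffled = sorted(names)               # sort first so the result ignores input order
    random.Random(seed).shuffle(shuffled)  # seeded shuffle keeps the split reproducible
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]
```

The two resulting lists can then be placed in separate folders and passed to `from_pascal_voc_folder` as the training and validation paths.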

2. I used the model in an Android program based on the code from:
https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android

The main code is:
package org.tensorflow.lite.examples.detection;

import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.graphics.Typeface;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.SystemClock;
import android.util.Size;
import android.util.TypedValue;
import android.widget.Toast;

import com.example.namespace.R;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.lite.examples.detection.customview.OverlayView;
import org.tensorflow.lite.examples.detection.customview.OverlayView.DrawCallback;
import org.tensorflow.lite.examples.detection.env.BorderedText;
import org.tensorflow.lite.examples.detection.env.ImageUtils;
import org.tensorflow.lite.examples.detection.env.Logger;
import org.tensorflow.lite.examples.detection.tflite.Detector;
import org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel;
import org.tensorflow.lite.examples.detection.tracking.MultiBoxTracker;

/**
 * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track
 * objects.
 */
public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
  private static final Logger LOGGER = new Logger();

  // Configuration values for the prepackaged SSD model.
  private static final int TF_OD_API_INPUT_SIZE = 300;
  private static final boolean TF_OD_API_IS_QUANTIZED = true;
  //private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
  private static final String TF_OD_API_MODEL_FILE = "dogs.tflite";
  private static final String TF_OD_API_LABELS_FILE = "labelmap.txt";
  private static final DetectorMode MODE = DetectorMode.TF_OD_API;
  // Minimum detection confidence to track a detection.
  private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.5f;
  private static final boolean MAINTAIN_ASPECT = false;
  private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
  private static final boolean SAVE_PREVIEW_BITMAP = false;
  private static final float TEXT_SIZE_DIP = 10;
  OverlayView trackingOverlay;
  private Integer sensorOrientation;

  private Detector detector;

  private long lastProcessingTimeMs;
  private Bitmap rgbFrameBitmap = null;
  private Bitmap croppedBitmap = null;
  private Bitmap cropCopyBitmap = null;

  private boolean computingDetection = false;

  private long timestamp = 0;

  private Matrix frameToCropTransform;
  private Matrix cropToFrameTransform;

  private MultiBoxTracker tracker;

  private BorderedText borderedText;

  @Override
  public void onPreviewSizeChosen(final Size size, final int rotation) {
    final float textSizePx =
        TypedValue.applyDimension(
            TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
    borderedText = new BorderedText(textSizePx);
    borderedText.setTypeface(Typeface.MONOSPACE);

    tracker = new MultiBoxTracker(this);

    int cropSize = TF_OD_API_INPUT_SIZE;

    try {
      detector =
          TFLiteObjectDetectionAPIModel.create(
              this,
              TF_OD_API_MODEL_FILE,
              TF_OD_API_LABELS_FILE,
              TF_OD_API_INPUT_SIZE,
              TF_OD_API_IS_QUANTIZED);
      cropSize = TF_OD_API_INPUT_SIZE;
    } catch (final IOException e) {
      e.printStackTrace();
      LOGGER.e(e, "Exception initializing Detector!");
      Toast toast =
          Toast.makeText(
              getApplicationContext(), "Detector could not be initialized", Toast.LENGTH_SHORT);
      toast.show();
      finish();
    }

    previewWidth = size.getWidth();
    previewHeight = size.getHeight();

    sensorOrientation = rotation - getScreenOrientation();
    LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);

    LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
    rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
    croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);

    frameToCropTransform =
        ImageUtils.getTransformationMatrix(
            previewWidth, previewHeight,
            cropSize, cropSize,
            sensorOrientation, MAINTAIN_ASPECT);

    cropToFrameTransform = new Matrix();
    frameToCropTransform.invert(cropToFrameTransform);

    trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay);
    trackingOverlay.addCallback(
        new DrawCallback() {
          @Override
          public void drawCallback(final Canvas canvas) {
            tracker.draw(canvas);
            if (isDebug()) {
              tracker.drawDebug(canvas);
            }
          }
        });

    tracker.setFrameConfiguration(previewWidth, previewHeight, sensorOrientation);
  }

  @Override
  protected void processImage() {
    ++timestamp;
    final long currTimestamp = timestamp;
    trackingOverlay.postInvalidate();

    // No mutex needed as this method is not reentrant.
    if (computingDetection) {
      readyForNextImage();
      return;
    }
    computingDetection = true;
    LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");

    rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);

    readyForNextImage();

    final Canvas canvas = new Canvas(croppedBitmap);
    canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
    // For examining the actual TF input.
    if (SAVE_PREVIEW_BITMAP) {
      ImageUtils.saveBitmap(croppedBitmap);
    }

    runInBackground(
        new Runnable() {
          @Override
          public void run() {
            LOGGER.i("Running detection on image " + currTimestamp);
            final long startTime = SystemClock.uptimeMillis();
            final List<Detector.Recognition> results = detector.recognizeImage(croppedBitmap);
            lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;

            cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
            final Canvas canvas = new Canvas(cropCopyBitmap);
            final Paint paint = new Paint();
            paint.setColor(Color.RED);
            paint.setStyle(Style.STROKE);
            paint.setStrokeWidth(2.0f);

            float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
            switch (MODE) {
              case TF_OD_API:
                minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
                break;
            }

            final List<Detector.Recognition> mappedRecognitions =
                new ArrayList<Detector.Recognition>();

            for (final Detector.Recognition result : results) {
              final RectF location = result.getLocation();
              if (location != null && result.getConfidence() >= minimumConfidence) {
                canvas.drawRect(location, paint);

                cropToFrameTransform.mapRect(location);

                result.setLocation(location);
                mappedRecognitions.add(result);
              }
            }

            tracker.trackResults(mappedRecognitions, currTimestamp);
            trackingOverlay.postInvalidate();

            computingDetection = false;

            runOnUiThread(
                new Runnable() {
                  @Override
                  public void run() {
                    showFrameInfo(previewWidth + "x" + previewHeight);
                    showCropInfo(cropCopyBitmap.getWidth() + "x" + cropCopyBitmap.getHeight());
                    showInference(lastProcessingTimeMs + "ms");
                  }
                });
          }
        });
  }

  @Override
  protected int getLayoutId() {
    return R.layout.tfe_od_camera_connection_fragment_tracking;
  }

  @Override
  protected Size getDesiredPreviewFrameSize() {
    return DESIRED_PREVIEW_SIZE;
  }

  // Which detection model to use: by default uses Tensorflow Object Detection API frozen
  // checkpoints.
  private enum DetectorMode {
    TF_OD_API;
  }

  @Override
  protected void setUseNNAPI(final boolean isChecked) {
    runInBackground(
        () -> {
          try {
            detector.setUseNNAPI(isChecked);
          } catch (UnsupportedOperationException e) {
            LOGGER.e(e, "Failed to set \"Use NNAPI\".");
            runOnUiThread(
                () -> {
                  Toast.makeText(this, e.getMessage(), Toast.LENGTH_LONG).show();
                });
          }
        });
  }

  @Override
  protected void setNumThreads(final int numThreads) {
    runInBackground(() -> detector.setNumThreads(numThreads));
  }
}
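The exception quoted in this report is raised by `org.tensorflow.lite.task.vision.detector.ObjectDetector`, which the Android example uses to load the model. Since the model was trained with MediaPipe Model Maker, a possible cross-check (not something tried in this thread) is to load the same file on a desktop with the MediaPipe Tasks `ObjectDetector`. The sketch below assumes the `mediapipe` pip package is installed; `detect` and `top_labels` are hypothetical helper names, and the 0.5 threshold mirrors `MINIMUM_CONFIDENCE_TF_OD_API` in the Java code.

```python
from typing import List, Tuple


def detect(model_path: str, image_path: str, score_threshold: float = 0.5):
    """Run the MediaPipe Tasks ObjectDetector on one image file."""
    # Imported lazily so top_labels below can be used without mediapipe installed.
    import mediapipe as mp
    from mediapipe.tasks import python
    from mediapipe.tasks.python import vision

    options = vision.ObjectDetectorOptions(
        base_options=python.BaseOptions(model_asset_path=model_path),
        score_threshold=score_threshold,
    )
    with vision.ObjectDetector.create_from_options(options) as detector:
        return detector.detect(mp.Image.create_from_file(image_path))


def top_labels(result) -> List[Tuple[str, float]]:
    """Return (label, score) of the first category of each detection."""
    return [
        (d.categories[0].category_name, d.categories[0].score)
        for d in result.detections
    ]
```

If `top_labels(detect('dogs.tflite', 'some_image.jpg'))` returns sensible labels, the exported file itself is usable and the failure is more likely a mismatch between the model and the loader used on Android. The image path here is a placeholder.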

Other info / Complete Logs

The error log is:
Error getting native address of native library: task_vision_jni
java.lang.IllegalArgumentException: Error occurred when initializing ObjectDetector: Mobile SSD models are expected to have exactly 4 outputs, found 2
    at org.tensorflow.lite.task.vision.detector.ObjectDetector.initJniWithByteBuffer(Native Method)
    at org.tensorflow.lite.task.vision.detector.ObjectDetector.access$100(ObjectDetector.java:88)
    at org.tensorflow.lite.task.vision.detector.ObjectDetector$3.createHandle(ObjectDetector.java:223)
    at org.tensorflow.lite.task.core.TaskJniUtils.createHandleFromLibrary(TaskJniUtils.java:91)
    at org.tensorflow.lite.task.vision.detector.ObjectDetector.createFromBufferAndOptions(ObjectDetector.java:219)
    at org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel.<init>(TFLiteObjectDetectionAPIModel.java:87)
    at org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel.create(TFLiteObjectDetectionAPIModel.java:81)
    at org.tensorflow.lite.examples.detection.DetectorActivity.onPreviewSizeChosen(DetectorActivity.java:103)
    at org.tensorflow.lite.examples.detection.CameraActivity$7.onPreviewSizeChosen(CameraActivity.java:448)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment.setUpCameraOutputs(CameraConnectionFragment.java:360)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment.openCamera(CameraConnectionFragment.java:365)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment.-$$Nest$mopenCamera(Unknown Source:0)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment$3.onSurfaceTextureAvailable(CameraConnectionFragment.java:174)
    at android.view.TextureView.getTextureLayer(TextureView.java:410)
    at android.view.TextureView.draw(TextureView.java:353)
    at android.view.View.updateDisplayListIfDirty(View.java:21885)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.draw(View.java:23021)
    at android.view.View.updateDisplayListIfDirty(View.java:21885)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at androidx.coordinatorlayout.widget.CoordinatorLayout.drawChild(CoordinatorLayout.java:1246)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.draw(View.java:23021)
    at android.view.View.updateDisplayListIfDirty(View.java:21885)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
2024-07-31 00:44:09.336 31317-31317 TaskJniUtils org...lite.examples.objectdetection E  at android.view.View.draw(View.java:23021)
    at com.android.internal.policy.DecorView.draw(DecorView.java:891)
    at android.view.View.updateDisplayListIfDirty(View.java:21885)
    at android.view.ThreadedRenderer.updateViewTreeDisplayList(ThreadedRenderer.java:534)
    at android.view.ThreadedRenderer.updateRootDisplayList(ThreadedRenderer.java:542)
    at android.view.ThreadedRenderer.draw(ThreadedRenderer.java:625)
    at android.view.ViewRootImpl.draw(ViewRootImpl.java:4657)
    at android.view.ViewRootImpl.performDraw(ViewRootImpl.java:4375)
    at android.view.ViewRootImpl.performTraversals(ViewRootImpl.java:3486)
    at android.view.ViewRootImpl.doTraversal(ViewRootImpl.java:2277)
    at android.view.ViewRootImpl$TraversalRunnable.run(ViewRootImpl.java:9037)
    at android.view.Choreographer$CallbackRecord.run(Choreographer.java:1142)
    at android.view.Choreographer.doCallbacks(Choreographer.java:946)
    at android.view.Choreographer.doFrame(Choreographer.java:875)
    at android.view.Choreographer$FrameDisplayEventReceiver.run(Choreographer.java:1127)
    at android.os.Handler.handleCallback(Handler.java:938)
    at android.os.Handler.dispatchMessage(Handler.java:99)
    at android.os.Looper.loopOnce(Looper.java:210)
    at android.os.Looper.loop(Looper.java:299)
    at android.app.ActivityThread.main(ActivityThread.java:8293)
    at java.lang.reflect.Method.invoke(Native Method)
    at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:556)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1045)
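Most of this trace is Android view-drawing frames. To find the frames that matter, it can help to filter the trace by package. The sketch below is an illustration rather than part of the original report: `frames` and `first_frame_in` are hypothetical helpers that pull the method names out of `at ...(...)` lines using only the standard library.

```python
import re
from typing import List, Optional

# Matches "at some.package.Class.method(" anywhere in a line, so lines that
# carry a logcat prefix before the frame are still picked up.
FRAME = re.compile(r"\bat ([\w.$<>-]+)\(")


def frames(trace: str) -> List[str]:
    """Return the fully qualified method of every stack frame, top first."""
    found = []
    for line in trace.splitlines():
        match = FRAME.search(line)
        if match:
            found.append(match.group(1))
    return found


def first_frame_in(trace: str, package_prefix: str) -> Optional[str]:
    """Return the topmost frame whose name starts with package_prefix, if any."""
    for name in frames(trace):
        if name.startswith(package_prefix):
            return name
    return None
```

Applied to the log above, `first_frame_in(log, "org.tensorflow.lite.examples")` picks out the `TFLiteObjectDetectionAPIModel.<init>` frame: the example app's wrapper calls `ObjectDetector.createFromBufferAndOptions` from the `org.tensorflow.lite.task` package, and that call is where the output-count check fails.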
kuaashish commented 2 months ago

Hi @libofei2004,

Apologies for the delayed response. Could you please let us know if this issue is still ongoing or if it has been resolved on your end?

Thank you!!

github-actions[bot] commented 1 month ago

This issue has been marked stale because it has no recent activity since 7 days. It will be closed if no further activity occurs. Thank you.

github-actions[bot] commented 1 month ago

This issue was closed due to lack of activity after being marked stale for past 7 days.
