google-ai-edge / mediapipe

Cross-platform, customizable ML solutions for live and streaming media.
https://mediapipe.dev
Apache License 2.0

Customized .tflite ResNet50 model for Object Classification on Web does not work #5537

Open: RubensZimbres opened this issue 1 month ago

RubensZimbres commented 1 month ago

Have I written custom code (as opposed to using a stock example script provided in MediaPipe)

No

OS Platform and Distribution

Ubuntu 22.04

Python Version

3.10

MediaPipe Model Maker version

I didn't use Model Maker; I used a PyTorch ResNet model converted with ai-edge-torch.

Task name (e.g. Image classification, Gesture recognition etc.)

Image classification

Describe the actual behavior

The CodePen tutorial works with the stock EfficientNet .tflite model, but not with the custom model converted with ai-edge-torch.

Describe the expected behaviour

Since the HTML code works with the supported EfficientNet .tflite model, I expected it to work with the custom .tflite model as well, especially because the custom model loads successfully in the MediaPipe Studio web interface at https://mediapipe-studio.webapps.google.com/home. It does not load in my HTML page, however.

Standalone code/steps you may have used to try to get what you need

The VS Code debugger shows these errors:

Could not read source map for https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision: Unexpected 404 response from https://cdn.jsdelivr.net/npm/@mediapipe/vision_bundle_mjs.js.map: Failed to resolve the requested file.
Uncaught TypeError TypeError: Failed to fetch
    at l (cdn.jsdelivr.net/npm/@mediapipe/tasks-vision:7:47151)
    at l (cdn.jsdelivr.net/npm/@mediapipe/tasks-vision:7:76591)
    at o (cdn.jsdelivr.net/npm/@mediapipe/tasks-vision:7:121629)
    at Zo (cdn.jsdelivr.net/npm/@mediapipe/tasks-vision:7:45744)

Chrome DevTools shows these errors:

Refused to execute script from 'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/vision_bundle.cjs' because its MIME type ('application/node') is not executable, and strict MIME type checking is enabled.
ESSE__MEDIAPIPE.html:1 Access to fetch at 'https://storage.googleapis.com/xxxxxxxxxx/resnet50_quantized.tflite' from origin 'null' has been blocked by CORS policy: No 'Access-Control-Allow-Origin' header is present on the requested resource. If an opaque response serves your needs, set the request's mode to 'no-cors' to fetch the resource with CORS disabled.
storage.googleapis.com/xxxxxxxxxxx/resnet50_quantized.tflite:1 Failed to load resource: net::ERR_FAILED
tasks-vision:7 Uncaught (in promise) TypeError: Failed to fetch
    at Mh.l (tasks-vision:7:47151)
    at Mh.l (tasks-vision:7:76591)
    at Mh.o (tasks-vision:7:121629)
    at Zo (tasks-vision:7:45744)
    at async createImageClassifier (ESSE__MEDIAPIPE.html:194:21)
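The final trace shows the rejection inside `createImageClassifier` surfacing as an uncaught `TypeError: Failed to fetch`; browsers report a CORS block to script as exactly this kind of generic `TypeError`. A minimal sketch, assuming the page fetches the model itself and hands the bytes to MediaPipe through `baseOptions.modelAssetBuffer` instead of `modelAssetPath`. `loadModelBytes` is a hypothetical helper, not part of MediaPipe, and its `fetchImpl` parameter exists only so it can be exercised outside a browser.

```javascript
// Hypothetical helper (not a MediaPipe API): fetch the .tflite bytes directly so that
// a network, CORS, or HTTP failure produces a readable error message.
async function loadModelBytes(url, fetchImpl = globalThis.fetch) {
  let response;
  try {
    response = await fetchImpl(url, { mode: "cors" });
  } catch (err) {
    // A CORS rejection reaches script only as a generic TypeError ("Failed to fetch").
    throw new Error(
      `Could not fetch ${url}: network or CORS failure. Check the bucket's CORS ` +
        `configuration and whether the page is opened from file:// (origin "null"). ` +
        `Original error: ${err.message}`
    );
  }
  // HTTP errors such as 403 or 404 do not reject the fetch; check them explicitly.
  if (!response.ok) {
    throw new Error(`Fetching ${url} failed with HTTP ${response.status}`);
  }
  // baseOptions.modelAssetBuffer accepts a Uint8Array.
  return new Uint8Array(await response.arrayBuffer());
}
```

In the page, the result would take the place of the path: `baseOptions: { modelAssetBuffer: await loadModelBytes(url) }`. This does not work around CORS, since the bucket still has to send `Access-Control-Allow-Origin` for the page's origin; it only makes the failure explicit, and wrapping the call in `try`/`catch` also avoids the "Uncaught (in promise)" report.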

Here's my code:

<!DOCTYPE html>
<html lang="en">
<head>

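  <!-- Not needed: the <script type="module"> below imports @mediapipe/tasks-vision directly.
       Chrome refuses to run vision_bundle.cjs as a classic script because it is served as application/node. -->
  <script src="https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/vision_bundle.cjs"
    crossorigin="anonymous"></script>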
    <style>

/* Copyright 2023 The MediaPipe Authors.
https://ai.google.dev/edge/mediapipe/solutions/vision/image_classifier/web_js

Error: INVALID_ARGUMENT: Classification tflite models are assumed to have a single subgraph.; Initialize was not ok; StartGraph failed

Error: NOT_FOUND: Input tensor has type float32: it requires specifying NormalizationOptions metadata to preprocess input images.; Initialize was not ok; StartGraph failed

Validate tflite: https://netron.app/

 */

body {
  font-family: roboto;
  margin: 2em;
  color: #3d3d3d;
  --mdc-theme-primary: #007f8b;
  --mdc-theme-on-primary: #f1f3f4;
}

h1 {
  color: #007f8b;
}

h2 {
  clear: both;
}

video {
  clear: both;
  display: block;
}

section {
  opacity: 1;
  transition: opacity 500ms ease-in-out;
}

.mdc-button.mdc-button--raised.removed {
  display: none;
}

.removed {
  display: none;
}

.invisible {
  opacity: 0.2; 
}

.videoView,
.classifyOnClick {
  position: relative;
  float: left;
  width: 48%;
  margin: 2% 1%;
  cursor: pointer;
}

.videoView p,
.classifyOnClick p {
  padding: 5px;
  background-color: #007f8b;
  color: #fff;
  z-index: 2;
  margin: 0;
}

.highlighter {
  background: rgba(0, 255, 0, 0.25);
  border: 1px dashed #fff;
  z-index: 1;
  position: absolute;
}

.classifyOnClick {
  z-index: 0;
  font-size: calc(8px + 1.2vw);
}

.classifyOnClick img {
  width: 100%;
}

.webcamPredictions {
  padding-top: 5px;
  padding-bottom: 5px;
  background-color: #007f8b;
  color: #fff;
  border: 1px dashed rgba(255, 255, 255, 0.7);
  z-index: 2;
  margin: 0;
  width: 100%;
  font-size: calc(8px + 1.2vw);
}

    </style>
</head>
<body>

  <link href="https://unpkg.com/material-components-web@latest/dist/material-components-web.min.css" rel="stylesheet">
  <script src="https://unpkg.com/material-components-web@latest/dist/material-components-web.min.js"></script>

<!-- Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. -->

<h1>Classifying images using the MediaPipe Image Classifier Task</h1>

<section id="demos" class="invisible">
  <h2>Demo: Classify Images</h2>
  <p><b>Click on an image below</b> to see its classification.</p>
  <div class="classifyOnClick">
    <img src="https://assets.codepen.io/9177687/dog_flickr_publicdomain.jpeg" width="100%" crossorigin="anonymous" title="Click to get classification!" />
    <p class="classification removed">
    </p>
  </div>
  <div class="classifyOnClick">
    <img src="https://assets.codepen.io/9177687/cat_flickr_publicdomain.jpeg" width="100%" crossorigin="anonymous" title="Click to get classification!" />
    <p class="classification removed">
    </p>
  </div>

  <h2>Demo: Webcam continuous classification</h2>
  <p>Hold some objects up close to your webcam to get real-time classification. For best results, avoid having too many objects visible to the camera.<br>Click <b>enable webcam</b> below and grant access to the webcam if prompted.</p>

  <div class="webcam">
    <button id="webcamButton" class="mdc-button mdc-button--raised">
      <span class="mdc-button__ripple"></span>
      <span class="mdc-button__label">ENABLE WEBCAM</span>
    </button>
    <video id="webcam" autoplay playsinline></video>
    <p id="webcamPredictions" class="webcamPredictions removed"></p>
  </div>
</section>

<script type="module">
import {
  ImageClassifier,
  FilesetResolver
} from "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision";

// Get DOM elements
const video = document.getElementById("webcam");
const webcamPredictions = document.getElementById("webcamPredictions");
const demosSection = document.getElementById("demos");
let enableWebcamButton;
let webcamRunning = false;
const videoHeight = "360px";
const videoWidth = "480px";

const imageContainers = document.getElementsByClassName(
  "classifyOnClick"
);
let runningMode = "IMAGE";

// Add click event listeners for the img elements.
for (let i = 0; i < imageContainers.length; i++) {
  imageContainers[i].children[0].addEventListener("click", handleClick);
}

// Track imageClassifier object and load status.
let imageClassifier;

/**
 * Create an ImageClassifier from the given options.
 * You can replace the model with a custom one.
 */
const createImageClassifier = async () => {
  const vision = await FilesetResolver.forVisionTasks(
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm"
  );
  imageClassifier = await ImageClassifier.createFromOptions(vision, {
    baseOptions: {
      modelAssetPath: `https://storage.googleapis.com/xxxxxxxxxxxxx/resnet50_quantized.tflite`
      // NOTE: For this demo, we keep the default CPU delegate.
      // Stock model that works: https://storage.googleapis.com/mediapipe-models/image_classifier/efficientnet_lite0/float32/1/efficientnet_lite0.tflite
      // https://storage.googleapis.com/xxxxxxxxxxxxxxx/resnet50_quantized.tflite
    },
    maxResults: 3,
    runningMode: runningMode
  });

  // Show demo section now model is ready to use.
  demosSection.classList.remove("invisible");
};
createImageClassifier();

/**
 * Demo 1: Classify images on click and display results.
 */
async function handleClick(event) {
  // Do not classify if imageClassifier hasn't loaded
  if (imageClassifier === undefined) {
    return;
  }
  // if video mode is initialized, set runningMode to image
  if (runningMode === "VIDEO") {
    runningMode = "IMAGE";
    await imageClassifier.setOptions({ runningMode: "IMAGE" });
  }

  // imageClassifier.classify() runs synchronously and returns the classification result directly (not a promise).
  // Use the result to print out the prediction.
  const classificationResult = imageClassifier.classify(event.target);
  // Write the predictions to a new paragraph element and add it to the DOM.
  const classifications = classificationResult.classifications;

  const p = event.target.parentNode.childNodes[3];
  p.className = "classification";
  p.innerText =
    "Classification: " +
    classifications[0].categories[0].categoryName +
    "\n Confidence: " +
    Math.round(parseFloat(classifications[0].categories[0].score) * 100) +
    "%";
  classificationResult.close();
}

/********************************************************************
// Demo 2: Continuously grab image from webcam stream and classify it.
********************************************************************/

// Check if webcam access is supported.
function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

// Get classification from the webcam
async function predictWebcam() {
  // Do not classify if imageClassifier hasn't loaded
  if (imageClassifier === undefined) {
    return;
  }
  // if image mode is initialized, switch the classifier to VIDEO runningMode
  if (runningMode === "IMAGE") {
    runningMode = "VIDEO";
    await imageClassifier.setOptions({ runningMode: "VIDEO" });
  }
  const startTimeMs = performance.now();
  const classificationResult = imageClassifier.classifyForVideo(
      video,
      startTimeMs
    );
  video.style.height = videoHeight;
  video.style.width = videoWidth;
  webcamPredictions.style.width = videoWidth;
  const classifications = classificationResult.classifications;
  webcamPredictions.className = "webcamPredictions";
  webcamPredictions.innerText =
    "Classification: " +
    classifications[0].categories[0].categoryName +
    "\n Confidence: " +
    Math.round(parseFloat(classifications[0].categories[0].score) * 100) +
    "%";
  // Call this function again to keep predicting when the browser is ready.
  if (webcamRunning === true) {
    window.requestAnimationFrame(predictWebcam);
  }
}

// Enable the live webcam view and start classification.
async function enableCam(event) {
  if (imageClassifier === undefined) {
    return;
  }

  if (webcamRunning === true) {
    webcamRunning = false;
    enableWebcamButton.innerText = "ENABLE PREDICTIONS";
  } else {
    webcamRunning = true;
    enableWebcamButton.innerText = "DISABLE PREDICTIONS";
  }

  // getUsermedia parameters.
  const constraints = {
    video: true
  };

  // Activate the webcam stream.
  video.srcObject = await navigator.mediaDevices.getUserMedia(constraints);
  video.addEventListener("loadeddata", predictWebcam);
}

// If webcam supported, add event listener to button.
if (hasGetUserMedia()) {
  enableWebcamButton = document.getElementById("webcamButton");
  enableWebcamButton.addEventListener("click", enableCam);
} else {
  console.warn("getUserMedia() is not supported by your browser");
}

</script>

</body>
</html>
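The notes inside the `<style>` comment quote `Input tensor has type float32: it requires specifying NormalizationOptions metadata to preprocess input images`. In TFLite model metadata, NormalizationOptions carries a per-channel `mean` and `std`, and the preprocessing it describes is `(pixel - mean) / std`. A minimal sketch of that arithmetic; `normalizeRgba` is a hypothetical helper that assumes RGBA bytes such as `ctx.getImageData(0, 0, w, h).data` and produces interleaved (HWC) RGB, whereas a model converted directly from PyTorch may expect channels-first (CHW) input.

```javascript
// Hypothetical helper: apply NormalizationOptions-style preprocessing,
// normalized = (pixel - mean[c]) / std[c], to RGBA bytes (alpha is dropped).
// Output layout is interleaved HWC RGB: [r0, g0, b0, r1, g1, b1, ...].
function normalizeRgba(rgba, mean, std) {
  const pixelCount = rgba.length / 4;
  const out = new Float32Array(pixelCount * 3);
  for (let p = 0; p < pixelCount; p++) {
    for (let c = 0; c < 3; c++) {
      out[p * 3 + c] = (rgba[p * 4 + c] - mean[c]) / std[c];
    }
  }
  return out;
}

// Example: mean = std = 127.5 maps 0..255 to -1..1.
const sample = Uint8ClampedArray.from([255, 0, 255, 255, 0, 0, 0, 255]);
console.log(Array.from(normalizeRgba(sample, [127.5, 127.5, 127.5], [127.5, 127.5, 127.5])));
// → [1, -1, 1, -1, -1, -1]
```

For a ResNet50 trained with torchvision's ImageNet preprocessing, the equivalent 0-255 values would be mean `[123.675, 116.28, 103.53]` and std `[58.395, 57.12, 57.375]` (255 × `[0.485, 0.456, 0.406]` and 255 × `[0.229, 0.224, 0.225]`). That is an assumption about how this particular model was trained, so it should be checked against the original training code.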
RubensZimbres commented 1 month ago

I solved the problem by adding a CORS configuration to the Cloud Storage bucket that hosts the model.

First, create a cors_file.json:

[
  {
    "origin": ["https://your-website.com"],
    "method": ["GET", "POST"],
    "responseHeader": ["Content-Type", "Authorization"],
    "maxAgeSeconds": 86400
  }
]

Then:

gcloud storage buckets update gs://your-bucket-with-.tflite --cors-file=cors_file.json
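The `origin 'null'` in the original CORS error suggests the HTML page was opened straight from disk (file://), whose origin browsers serialize as `"null"`; an entry that lists only `https://your-website.com` would not match it. A simplified sketch of the matching, assuming only exact origins and `"*"` are handled (Cloud Storage's actual rules may cover more cases); `corsAllows` is hypothetical.

```javascript
// Hypothetical, simplified check: would this CORS config (cors_file.json format)
// allow a request with `method` from a page at `pageUrl`? Handles exact origins and "*".
function corsAllows(corsConfig, pageUrl, method) {
  // The browser sends the page's origin; file:// URLs have an opaque origin, serialized as "null".
  const origin = new URL(pageUrl).origin;
  return corsConfig.some(
    (entry) =>
      (entry.origin.includes("*") || entry.origin.includes(origin)) &&
      entry.method.includes(method)
  );
}

const corsFile = [
  {
    origin: ["https://your-website.com"],
    method: ["GET", "POST"],
    responseHeader: ["Content-Type", "Authorization"],
    maxAgeSeconds: 86400,
  },
];
console.log(corsAllows(corsFile, "https://your-website.com/index.html", "GET")); // true
console.log(corsAllows(corsFile, "file:///home/user/ESSE__MEDIAPIPE.html", "GET")); // false
```

For local testing, one option is serving the page over HTTP (for example with `python3 -m http.server 8000`) and adding `http://localhost:8000` to the `origin` list.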

However, inference now takes about 5 seconds, compared with milliseconds for the default EfficientNet model.

How can I speed up inference? It looks like a version incompatibility between tflite-support and TensorFlow that prevents the saved .tflite model from being optimized.
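Comparing models is easier when latency is measured the same way for each. A minimal sketch, assuming a synchronous call such as `() => imageClassifier.classify(img)` in IMAGE mode; `medianLatencyMs` is hypothetical. The first few calls are discarded because they can include one-time setup cost, and the median is less sensitive to outliers than a single run.

```javascript
// Hypothetical helper: median latency (ms) of a synchronous inference call.
// The first `warmup` calls are discarded because they can include one-time setup cost.
function medianLatencyMs(runInference, runs = 20, warmup = 3) {
  for (let i = 0; i < warmup; i++) runInference();
  const samples = [];
  for (let i = 0; i < runs; i++) {
    const start = performance.now();
    runInference();
    samples.push(performance.now() - start);
  }
  samples.sort((a, b) => a - b);
  const mid = Math.floor(samples.length / 2);
  return samples.length % 2 === 1 ? samples[mid] : (samples[mid - 1] + samples[mid]) / 2;
}

// Stand-in workload so the sketch runs outside a browser; in the page this would be
// () => imageClassifier.classify(document.querySelector(".classifyOnClick img")).
const ms = medianLatencyMs(() => {
  let s = 0;
  for (let i = 0; i < 1e5; i++) s += Math.sqrt(i);
  return s;
});
console.log(`median: ${ms.toFixed(3)} ms`);
```

Running it once with the stock EfficientNet URL and once with the ResNet50 model, and optionally with `baseOptions.delegate` set to `"GPU"` instead of the default CPU, gives numbers that are comparable for the same image.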

kuaashish commented 1 month ago

Hi @tyrmullen,

Do you have any suggestions for speeding up inference using the customized model instead of the default one in our Web Task API? Any advice would be greatly appreciated.

Thank you!!

RubensZimbres commented 1 month ago

@kuaashish I noticed that quantizing the .tflite model causes a signature problem, and inference time rises to 2 seconds. Without quantization, inference takes 170 milliseconds.
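For reference, TFLite's int8 quantization is affine: each tensor (or channel) stores a `scale` and a `zeroPoint`, and `real = scale * (q - zeroPoint)` with `q` in [-128, 127]. A minimal sketch of that mapping; the helper names are hypothetical, and the sketch illustrates the number format only, not the runtime cost of executing a quantized graph.

```javascript
// Affine int8 quantization as described in the TFLite quantization spec:
// real = scale * (q - zeroPoint), with q clamped to the int8 range [-128, 127].
function quantize(x, scale, zeroPoint) {
  const q = Math.round(x / scale) + zeroPoint;
  return Math.min(127, Math.max(-128, q));
}

function dequantize(q, scale, zeroPoint) {
  return scale * (q - zeroPoint);
}

const scale = 0.25;
const zeroPoint = -10;
for (const x of [0, 1.3, 100]) {
  const q = quantize(x, scale, zeroPoint);
  // 1.3 rounds to the nearest representable step; 100 saturates at q = 127.
  console.log(x, "->", q, "->", dequantize(q, scale, zeroPoint));
}
```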