GoogleChromeLabs / web-ai-demos


Show a warning about large models being downloaded #65

Open bramus opened 1 month ago

bramus commented 1 month ago

I tried some of the web-ai-demos on https://chrome.dev/, such as https://chrome.dev/web-ai-demos/perf-client-side-gemma-worker/

Some demos say that the model will take about 30 seconds or 1 minute to load. It took much longer, as it turned out the demo was downloading a model … of more than 1 GB … which eventually took 15 minutes to complete.

Screenshot 2024-10-09 at 21 49 25

Please add a warning message as per https://web.dev/articles/client-side-ai-performance#signal_large_downloads guidelines.

From the looks of it, the model doesn’t get cached on disk properly, so people end up downloading the model over and over again.

tomayac commented 1 month ago

The proper way to fix this would be to add an option upstream to MediaPipe so it optionally caches the model. This would have to happen here:

https://github.com/google-ai-edge/mediapipe/blob/59f8ae3637fe2a12c26ea114a190505e85cfce13/mediapipe/tasks/web/core/task_runner.ts#L163-L172


A way to fix this just for this project is to fetch the model out-of-band, store it in the Cache API, and then pass a blob URL to LlmInference.createFromOptions(), as I do in this project:

https://github.com/tomayac/mediapipe-llm/blob/62ffef5754d37425625d17a53cdb72f65c8df793/script.js#L357
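Roughly, that could look something like the sketch below. The model URL, cache name, and WASM CDN path are placeholders, not values from these demos:

  import { FilesetResolver, LlmInference } from '@mediapipe/tasks-genai';

  const MODEL_URL = 'gemma-2b-it-gpu-int4.bin'; // placeholder
  const MODEL_CACHE = 'model-cache';            // placeholder

  async function loadCachedModel() {
    const cache = await caches.open(MODEL_CACHE);

    // Serve the model from the Cache API if a previous visit stored it.
    let response = await cache.match(MODEL_URL);
    if (!response) {
      response = await fetch(MODEL_URL);
      // Keep a copy so later visits skip the multi-gigabyte download.
      await cache.put(MODEL_URL, response.clone());
    }

    // Hand MediaPipe a blob URL that points at the now-local model bytes.
    const modelBlobUrl = URL.createObjectURL(await response.blob());

    const genaiFileset = await FilesetResolver.forGenAiTasks(
        'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm');
    return LlmInference.createFromOptions(genaiFileset, {
      baseOptions: { modelAssetPath: modelBlobUrl },
      maxTokens: 6144,
    });
  }

The first visit still pays the full download, but subsequent visits read the model from the Cache API instead of hitting the network again.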

Agree that a warning would make sense, too.

andreban commented 3 weeks ago

CC @maudnals

andreban commented 3 weeks ago

MediaPipe accepts a modelAssetBuffer instead of a modelAssetPath, and modelAssetBuffer can be a Uint8Array.

So, the code to download the model while reporting progress, and then load it, can look something like this:

  const modelFileName = 'gemma-2b-it-gpu-int4.bin';
  const modelResponse = await fetch(modelFileName);
  const reader = modelResponse.body.getReader();

  const contentLength = modelResponse.headers.get('Content-Length');
  let receivedLength = 0;
  let chunks = [];

  // Read the response stream chunk by chunk so progress can be reported.
  while (true) {
      const {done, value} = await reader.read();
      if (done) {
          break;
      }

      chunks.push(value);
      receivedLength += value.length;

      console.log(`Received ${receivedLength} of ${contentLength}`);
  }

  // Concatenate the chunks into a single Uint8Array for MediaPipe.
  let modelData = new Uint8Array(receivedLength);
  let position = 0;
  for (let chunk of chunks) {
      modelData.set(chunk, position);
      position += chunk.length;
  }

  const llmInference = await LlmInference.createFromOptions(genaiFileset, {
      baseOptions: { modelAssetBuffer: modelData },
      maxTokens: 6144,
  });
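
To also cover the warning asked for in this issue, the same loop could drive a user-visible indicator instead of console.log. A minimal sketch, assuming the page has <progress id="model-progress"> and <span id="model-status"> elements (those IDs are made up, the demos don't have them):

  const progressEl = document.querySelector('#model-progress');
  const statusEl = document.querySelector('#model-status');

  // Call this from the read loop above instead of console.log.
  function reportProgress(receivedLength, contentLength) {
      if (contentLength) {
          progressEl.max = Number(contentLength);
          progressEl.value = receivedLength;
      }
      statusEl.textContent =
          `Downloading a large model (~${Math.round(receivedLength / 1e6)} MB received so far)…`;
  }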