google-ai-edge / mediapipe

Cross-platform, customizable ML solutions for live and streaming media.
https://mediapipe.dev
Apache License 2.0

Repeated tokens generated when a long prompt is input to GPU models during Android LLM inference #5553

Open neilsun2009 opened 1 month ago

neilsun2009 commented 1 month ago

Have I written custom code (as opposed to using a stock example script provided in MediaPipe)

None

OS Platform and Distribution

Android 14

Mobile device if the issue happens on mobile device

Xiaomi 14 Pro

Browser and version if the issue happens on browser

No response

Programming Language and version

Kotlin

MediaPipe version

0.10.14

Bazel version

No response

Solution

LLM Inference

Android Studio, NDK, SDK versions (if issue is related to building in Android environment)

SDK 34

Xcode & Tulsi version (if issue is related to building for iOS)

No response

Describe the actual behavior

Repeated tokens are generated if the prompt is too long

Describe the expected behaviour

The response should be generated correctly

Standalone code/steps you may have used to try to get what you need

  1. Deploy the example app
  2. Input a long prompt, e.g. the following text copied from https://huggingface.co/runwayml/stable-diffusion-v1-5:
Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input. For more information about how Stable Diffusion functions, please have a look at 🤗's Stable Diffusion blog.

The Stable-Diffusion-v1-5 checkpoint was initialized with the weights of the Stable-Diffusion-v1-2 checkpoint and subsequently fine-tuned on 595k steps at resolution 512x512 on "laion-aesthetics v2 5+" and 10% dropping of the text-conditioning to improve classifier-free guidance sampling.

You can use this both with the 🧨Diffusers library and the RunwayML GitHub repository.
  3. Observe the generated result

Other info / Complete Logs

I'm building a RAG-based mobile app using the MediaPipe LLM Inference API, with the Gemma 1.1 2B int8 GPU checkpoint downloaded directly from Kaggle.
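For context, the engine is set up and invoked roughly as below. This is a minimal sketch; the model path and option values are assumptions from my local setup, not anything prescribed by the example app:

```kotlin
import android.content.Context
import com.google.mediapipe.tasks.genai.llminference.LlmInference

// Minimal sketch of the LLM Inference setup used here.
// The model path and option values are assumptions from my local setup.
fun runGemma(context: Context, prompt: String): String {
    val options = LlmInference.LlmInferenceOptions.builder()
        .setModelPath("/data/local/tmp/llm/gemma-1.1-2b-it-gpu-int8.bin") // GPU checkpoint from Kaggle
        .setMaxTokens(1024)
        .setTopK(40)
        .setTemperature(0.8f)
        .setRandomSeed(0)
        .build()

    val llmInference = LlmInference.createFromOptions(context, options)
    return llmInference.generateResponse(prompt)
}
```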

Based on my experiments, the solution works fine with a short prompt, say when I retrieve only one document chunk and clip it to 100 characters, as in the following prompt:

You are an intelligent search engine. You will be provided with some retrieved context, as well as the users query. Your job is to understand the request, and answer based on the retrieved context. 
Here is the retrieved context 
--------------------------------------------------                                                                                                     
+ you're mentioning is called image net and uh I began building it um 2007 and spent the the next thre

Here is the user's query: what is imagenet.
Your answer:

The output is fine:

ImageNet is a massive dataset of over 14 million images divided into 1000 categories, allowing researchers and engineers to train deep learning models for various tasks, such as image classification, object detection, and semantic segmentation.
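For reference, the retrieved chunk is clipped and spliced into the prompt template roughly as follows. `buildRagPrompt` is a hypothetical helper of mine, not MediaPipe code:

```kotlin
// Sketch of how the RAG prompt is assembled from a retrieved chunk and a query.
// buildRagPrompt is a hypothetical helper, not part of MediaPipe.
fun buildRagPrompt(retrievedChunk: String, query: String, clipTo: Int? = 100): String {
    // Clip the retrieved context to `clipTo` characters when a limit is given.
    val retrievedContext = if (clipTo != null) retrievedChunk.take(clipTo) else retrievedChunk
    return """
        You are an intelligent search engine. You will be provided with some retrieved context, as well as the user's query. Your job is to understand the request, and answer based on the retrieved context.
        Here is the retrieved context
        --------------------------------------------------
        + $retrievedContext

        Here is the user's query: $query.
        Your answer:
    """.trimIndent()
}
```

Passing `clipTo = null` reproduces the longer prompt below, where the whole chunk is kept.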

But if I keep the whole retrieved text chunk, as in this prompt:

You are an intelligent search engine. You will be provided with some retrieved context, as well as the users query. Your job is to understand the request, and answer based on the retrieved context. 
Here is the retrieved context 
--------------------------------------------------  
+ you're mentioning is called image net and uh I began building it um 2007 and spent the the next three years pretty much with my graduate students building it and you asked me was there a problem building it where do I even begin um even at the conception of this project I was told that it it really was a bad idea I was a young assistant professor I remember it was my first year actually as assistant professor at Princeton and uh for example a very very um uh respected mentor of mine in the field

Here is the user's query: what is imagenet.
Your answer:

The model output is something like this, and it won't stop until the max token limit is reached:

GRANTED Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover Walkover

As Gemma 1.1 2B allows a context window of 8k tokens, the second prompt should not be a problem. Also, the CPU version of the same model works fine even with 5 retrieved chunks, so I assume the problem lies somewhere in the GPU execution path of MediaPipe LLM Inference.
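As a rough sanity check on that claim, a characters-per-token estimate puts even the longer prompt at only a few hundred tokens. The ~4 chars/token ratio below is an assumption, not Gemma's real tokenizer:

```kotlin
// Rough sanity check: estimate the token count from character length.
// The ~4 characters-per-token ratio is an approximation for English text,
// not the output of Gemma's actual tokenizer.
fun estimateTokens(prompt: String): Int = prompt.length / 4

// Returns true if the prompt plus the expected output budget fits the window.
fun fitsContextWindow(
    prompt: String,
    contextWindow: Int = 8192,
    maxOutputTokens: Int = 512,
): Boolean = estimateTokens(prompt) + maxOutputTokens <= contextWindow
```

By this estimate the second prompt is on the order of a few hundred tokens, well inside the 8k window, so context length should not be the limiting factor.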

kuaashish commented 1 month ago

Hi @neilsun2009,

Could you please confirm if you are testing this on a physical Xiaomi 14 Pro device or an emulator? This information will help us better understand the issue.

Thank you!!

neilsun2009 commented 1 month ago

Hi @kuaashish, it was tested on a physical device.

kuaashish commented 3 days ago

Hi @neilsun2009,

We apologize for the delayed response. Could you please test with the latest available version, 0.10.15, and let us know whether the issue still persists?

Thank you!!
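For anyone following along, upgrading is typically just a dependency bump on the published tasks-genai artifact. The Kotlin DSL snippet below is a sketch; the rest of the build file is assumed:

```kotlin
// app/build.gradle.kts — bump the MediaPipe GenAI tasks dependency to 0.10.15.
dependencies {
    implementation("com.google.mediapipe:tasks-genai:0.10.15")
}
```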