googleapis / python-aiplatform

A Python SDK for Vertex AI, a fully managed, end-to-end platform for data science and machine learning.
Apache License 2.0

GenerativeModel response hangs when multithreaded (Linux) #3365

Open jk1333 opened 4 months ago

jk1333 commented 4 months ago

The GenerativeModel response hangs when it is called from multiple threads on Linux. It works fine on Windows. I have also included a workaround in the sample below.

Environment details

Steps to reproduce

  1. Run the attached sample on Linux (Windows has no issues).
  2. The expected output is:
       Waiting to be completed
       Hello world! I am a large language model, trained by Google.
       Hello world! I am a large language model, trained by Google.
       Hello world! I am a large language model, trained by Google.
       Hello world! I am a large language model, trained by Google.
       Hello world! I am a large language model, trained by Google.

Code example

# example
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed

def analyze_text_gemini_pro(prompt):
    from vertexai.generative_models import GenerativeModel, HarmCategory, HarmBlockThreshold
    def get_model():
        return GenerativeModel("gemini-1.0-pro-001")
    responses = get_model().generate_content(
        prompt,
        generation_config={
            "candidate_count": 1,
            "max_output_tokens": 8192,
            "temperature": 0,
            "top_p": 0.5,
            "top_k": 1
        },
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        },
        stream=True
    )
    #print(f"=== Number of Tokens === \n{response._raw_response.usage_metadata}\n===")
    #print(response)
    #return responses                             # Workaround: return the stream here and iterate it in the main thread
    txt = ""
    for response in responses:                # Hangs here on Linux
        txt += response.text
    return txt

async_jobs = []

for i in range(0, 5):
    async_jobs.append((analyze_text_gemini_pro, ("Hello world",), i))

# Share one thread pool across all jobs.
executor = ThreadPoolExecutor(8)
futures = {executor.submit(func, *param): decor for func, param, decor in async_jobs}

print("Waiting to be completed")
for idx, future in enumerate(as_completed(futures), start=1):
    responses = future.result()
    print(responses)
    #for response in responses:              # Workaround
    #    print(response.text, end="")
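A minimal offline sketch of the workaround above: worker threads only start the streaming request and return the iterator, and the main thread drains each stream. The Vertex AI call is replaced by a hypothetical `fake_generate_content` generator that yields plain strings (real stream chunks expose `.text` instead), so this illustrates the control flow only, not the SDK itself.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Iterator


def fake_generate_content(prompt: str) -> Iterator[str]:
    """Hypothetical stand-in for generate_content(..., stream=True).

    Yields text chunks the way a streaming response yields partial responses.
    """
    for chunk in (prompt, "! ", "I am a large language model."):
        yield chunk


def start_stream(prompt: str) -> Iterator[str]:
    # Worker thread: only start the request and hand the iterator back.
    return fake_generate_content(prompt)


def run_with_main_thread_consumption(prompts: list[str]) -> list[str]:
    results = [""] * len(prompts)
    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = {executor.submit(start_stream, p): i for i, p in enumerate(prompts)}
        for future in as_completed(futures):
            idx = futures[future]
            # Main thread: drain the stream (the workaround from the issue).
            results[idx] = "".join(future.result())
    return results


if __name__ == "__main__":
    for text in run_with_main_thread_consumption(["Hello world"] * 5):
        print(text)
```

With the real SDK, `start_stream` would call `generate_content(..., stream=True)` and the main thread would concatenate `response.text` from each chunk, as in the commented-out workaround lines.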

Stack trace

N/A

Making sure to follow these steps will guarantee the quickest resolution possible.

Thanks!

Ark-kun commented 3 months ago

This is a weird issue. There are several changes that make your sample code work in my tests:

Update: The issue stopped reproducing for me. In any case, the issue cannot be in the Vertex SDK. It might be in the gRPC library, or the model server may have trouble when many large requests are sent at the same time.

Please try using strace, ltrace, gdb, or another debugging tool to find the exact location where the code is stuck.
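To complement strace/gdb with a Python-level view, the standard library can print every thread's stack. This is a sketch, assuming you wrap the reproduction script with it; `format_all_thread_stacks` is a hypothetical helper name, while `faulthandler.dump_traceback_later` and the CPython-specific `sys._current_frames` are standard-library calls. For a process that is already hung, an external tool such as `py-spy dump --pid <PID>` gives a similar picture.

```python
import faulthandler
import sys
import threading
import traceback


def format_all_thread_stacks() -> str:
    """Return the current Python stack of every live thread as one string."""
    names = {t.ident: t.name for t in threading.enumerate()}
    parts = []
    # sys._current_frames() maps each thread ident to its topmost frame (CPython).
    for ident, frame in sys._current_frames().items():
        parts.append(f"--- Thread {names.get(ident, '?')} (ident={ident}) ---\n")
        parts.append("".join(traceback.format_stack(frame)))
    return "".join(parts)


if __name__ == "__main__":
    # If the process is still running after 120 s, write every thread's
    # traceback to stderr without killing the process.
    faulthandler.dump_traceback_later(120, exit=False)

    # ... run the ThreadPoolExecutor reproduction from the issue here ...

    print(format_all_thread_stacks())
    faulthandler.cancel_dump_traceback_later()
```

A thread stuck inside the stream iteration should show the innermost SDK or gRPC frame it is waiting in, which is the location the maintainers asked for.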

AniketModi commented 2 months ago

I am also facing the same issue. I am calling the gemini-pro-vision API through the Python SDK, and when I send parallel requests from multiple threads, the call hangs.

import json
from concurrent.futures import ThreadPoolExecutor

import vertexai
from google.oauth2 import service_account
from vertexai.generative_models import (
    GenerationConfig, GenerativeModel, HarmBlockThreshold, HarmCategory)


def call_gemini_to_generate_summary(image, index):
    # `api_key` (service-account JSON string) and `gemini_contents` (the prompt
    # parts built from `image`) are defined in code not shown here.
    try:
        json_account_info = json.loads(api_key, strict=False)
        credentials = service_account.Credentials.from_service_account_info(
            json_account_info)
        project_id = credentials.project_id

        vertexai.init(project=project_id, location="us-central1", credentials=credentials)
        gemini_model = GenerativeModel("gemini-pro-vision")
        generation_config = GenerationConfig(
            temperature=0.0,
            top_k=32,
            candidate_count=1,
            max_output_tokens=2000
        )
        safety_settings = {
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        }
        responses = gemini_model.generate_content(
            contents=gemini_contents,
            stream=False,
            generation_config=generation_config,
            safety_settings=safety_settings)
        print(responses)
        print(responses.text)
        return responses.text
    except Exception as e:
        print(f"Exception occurred while making gemini call for index:{index} due to {e}")
        return ""


images = [image1, image2, image3, image4, image3]
indexes = [0, 1, 2, 3, 4]

with ThreadPoolExecutor(10) as executor:
    answer_future = executor.submit(call_gemini_to_generate_summary, image1, 0)
    results = executor.map(call_gemini_to_generate_summary, images, indexes)
    print(answer_future.result(timeout=1))
    for result in results:
        print(result)

AniketModi commented 2 months ago

Could someone from the Gemini team please look at this issue and help with a fix?

jk1333 commented 2 months ago

For now, I'm using the multiprocessing approach below (processes, not threads) for testing.

import streamlit as st
import multiprocessing

@st.cache_resource
def get_processpool():
    # Keep at least one worker process on single-CPU machines.
    return multiprocessing.Pool(max(1, multiprocessing.cpu_count() - 1))

@st.cache_resource
def analyze_text_gemini_pro(input, param1 = "NONE", param2 = "NONE"):
    from vertexai.generative_models import GenerativeModel, HarmCategory, HarmBlockThreshold
    def get_model():
        return GenerativeModel("gemini-1.0-pro-001")
    response = get_model().generate_content(
        input,
        generation_config={
            "candidate_count": 1,
            "max_output_tokens": 8192,
            "temperature": 0,
            "top_p": 0.5,
            "top_k": 1
        },
        safety_settings={
            HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        },
        stream=False
    )
    #print(f"=== Number of Tokens === \n{response._raw_response.usage_metadata}\n===")
    #print(response)
    return response.text

def llm_tasks(indexed_item):
    idx, item = indexed_item
    return (idx, (item[0](*item[1])))

if __name__ == "__main__":
    st.set_page_config(
        page_title="Multiprocess test",
        layout="wide", 
        initial_sidebar_state="auto"
    )

    progress = st.empty()

    items = {(analyze_text_gemini_pro, ('Hello there', 'Sending...1')): st.container(), 
             (analyze_text_gemini_pro, ('How are you?', 'Sending...2')): st.container(), 
             (analyze_text_gemini_pro, ('Who are you?', 'Sending...3')): st.container(), 
             (analyze_text_gemini_pro, ('What time is it now?', 'Sending...4', "Param2")): st.container(), 
             (analyze_text_gemini_pro, ('Do you have time?', 'Sending...5', "Param2")): st.container()}

    bar = progress.progress(0)
    for idx, (work_idx, result) in enumerate(get_processpool().imap_unordered(llm_tasks, enumerate(items.keys())), 1):
        list(items.values())[work_idx].text_area(f"Output {work_idx}", f"{result}")
        bar.progress(idx / len(items))
    progress.empty()
jk1333 commented 2 months ago

The weird thing is that it works well when I test on Linux environments provided by Google (such as Cloud Run and Cloud Shell) or on Windows, but other public Linux environments had problems. This difference probably comes from how each OS abstracts the socket and event layers.

soocheolnoh commented 1 month ago

In my case, the problem was with `_prediction_client`. Simply accessing this property before using the model instance from multiple threads may help.

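from vertexai.generative_models import GenerativeModel

model = GenerativeModel("gemini-1.0-pro-001")  # create the model once, before starting threads
_ = model._prediction_client                   # touch the property so the client is set up now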
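The comment does not say why this helps. One plausible reading (an assumption, not confirmed by the SDK maintainers in this thread) is that the client is created lazily on first property access, so several threads can race on that first access. The sketch below uses a hypothetical `LazyModel` class, not the Vertex AI SDK, to show the warm-up pattern: touch the lazy property once in the main thread, then share the instance across workers.

```python
from concurrent.futures import ThreadPoolExecutor


class LazyModel:
    """Hypothetical stand-in for a model whose client is built on first access.

    This is NOT the Vertex AI SDK; it only illustrates the warm-up pattern.
    """

    def __init__(self) -> None:
        self._client = None
        self.clients_created = 0

    @property
    def _prediction_client(self) -> object:
        # Unsynchronized lazy initialization: two threads arriving here at the
        # same time could both see None and each build a client.
        if self._client is None:
            self.clients_created += 1
            self._client = object()
        return self._client

    def generate_content(self, prompt: str) -> str:
        client = self._prediction_client
        return f"{prompt} via client {id(client)}"


def run_parallel(model: LazyModel, prompts: list[str]) -> list[str]:
    # Warm-up in the main thread, mirroring `_ = model._prediction_client`.
    _ = model._prediction_client
    with ThreadPoolExecutor(max_workers=8) as executor:
        return list(executor.map(model.generate_content, prompts))


if __name__ == "__main__":
    model = LazyModel()
    print(run_parallel(model, ["Hello there", "How are you?"]))
    print("clients created:", model.clients_created)  # 1, because of the warm-up
```

A lock around the lazy initialization would also avoid the race; warming up from the main thread is simply the least invasive option when you do not control the class.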
mutonby commented 3 weeks ago

I have a similar problem with Socket.IO: with stream=True, the API call freezes. It doesn't print any logs or do anything; the process simply freezes. I've tried adding a stack trace, but it didn't help. It's very strange; I'd appreciate it if someone could give us a hand.