google-gemini / generative-ai-python

The official Python library for the Google Gemini API
https://pypi.org/project/google-generativeai/
Apache License 2.0

Docs fail to document thread-safety of genai client, and it fails irrecoverably on multi-threaded use #211

Closed jpdaigle closed 1 month ago

jpdaigle commented 8 months ago

Description of the bug:

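The Generative Service Client and GenerativeModel classes don't document their thread-safety assumptions, and they don't appear to be usable in a multithreaded environment for making concurrent API requests.

I'd suggest either:
1. making the client safe for concurrent use from multiple threads, or
2. documenting that it is not thread-safe.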

Behaviour observed: After I tried to make concurrent calls to the generative text API, most calls failed with a 60 s timeout. The client never recovered: every new call attempt also froze for 60 s and then timed out with an error.

Sample error output:

 10%|▉         | 199/2047.0 [29:31<5:46:51, 11.26s/it]
HTTPConnectionPool(host='localhost', port=46423): Read timed out. (read timeout=60.0)
 10%|▉         | 204/2047.0 [30:22<4:59:27,  9.75s/it]
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
 10%|█         | 209/2047.0 [31:10<6:08:26, 12.03s/it]
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
 11%|█         | 216/2047.0 [31:43<3:43:00,  7.31s/it]
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
 11%|█         | 225/2047.0 [32:48<3:52:42,  7.66s/it]
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
 11%|█▏        | 231/2047.0 [33:38<4:22:00,  8.66s/it]
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
 12%|█▏        | 245/2047.0 [35:55<6:14:28, 12.47s/it]
HTTPConnectionPool(host='localhost', port=46423): Read timed out. (read timeout=60.0)
('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
 14%|█▍        | 296/2047.0 [43:38<4:30:46,  9.28s/it]
HTTPConnectionPool(host='localhost', port=46423): Read timed out. (read timeout=60.0)

Example snippet:

import google.generativeai as genai
from concurrent.futures import ThreadPoolExecutor
import tqdm
# [... other imports and genai.configure(...) setup elided in the original report ...]

safety_settings = ...  # elided in the original report
executor = ThreadPoolExecutor(max_workers=5)

def build_data_batch():
  ## build batches of data to process (elided in the original report)
  pass

def generate(data_batch):
  model_out = 'error'
  try:
    # this ends up failing whether or not the model client is
    # created freshly per-request, or shared across threads
    model = genai.GenerativeModel('models/gemini-pro', safety_settings=safety_settings)
    # build_prompt (not shown) turns one batch into a prompt string
    model_out = model.generate_content(build_prompt(data_batch)).text
  except Exception as e:
    print(e)
  return model_out

all_outputs = executor.map(generate, build_data_batch())

with open('./outputs.txt', 'w') as f:
  # totalbatches (not shown) is the number of batches, used only for the progress bar
  for result in tqdm.tqdm(all_outputs, total=totalbatches):
    f.write(result + '\n')  # one output per line

Actual vs expected behavior:

Actual: all calls fail.

Expected: this case should either work, or the client docs should state that the client is not thread-safe for concurrent use, given how common batch-inference scenarios are likely to be.


adenalhardan commented 6 months ago

Hey, I'm blocked on the same issue. Did you manage to find any sort of workaround?

jpdaigle commented 6 months ago

@adenalhardan Yes. Instead of using the API client library, I ended up calling the HTTP API endpoint directly. With independent HTTP clients (one per worker thread), I had no problems.
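No code was posted for this workaround. Below is a minimal sketch of the idea (one independent HTTP client per worker thread), assuming the public Gemini REST endpoint (generativelanguage.googleapis.com, v1beta generateContent) authenticated with an API key, and using only the standard library's http.client. The model name and all helper names (get_connection, generate, generate_all) are hypothetical; the commenter may have used a different HTTP library or endpoint.

```python
import json
import threading
import http.client
from concurrent.futures import ThreadPoolExecutor

# Assumption: public Gemini API REST endpoint, authenticated with an API key.
API_HOST = "generativelanguage.googleapis.com"
MODEL = "gemini-pro"  # hypothetical choice; substitute the model you actually use

_local = threading.local()  # each worker thread sees its own attributes here


def get_connection() -> http.client.HTTPSConnection:
    """Return this thread's private HTTPS connection, creating it on first use."""
    conn = getattr(_local, "conn", None)
    if conn is None:
        # Constructing the object does not open a socket; that happens on the first request.
        conn = http.client.HTTPSConnection(API_HOST, timeout=60)
        _local.conn = conn
    return conn


def build_request_body(prompt: str) -> bytes:
    """JSON body for a single-turn generateContent call."""
    body = {"contents": [{"role": "user", "parts": [{"text": prompt}]}]}
    return json.dumps(body).encode("utf-8")


def generate(prompt: str, api_key: str) -> str:
    """Send one prompt over this thread's own connection and return the text reply."""
    conn = get_connection()
    path = f"/v1beta/models/{MODEL}:generateContent?key={api_key}"
    try:
        conn.request("POST", path, body=build_request_body(prompt),
                     headers={"Content-Type": "application/json"})
        resp = conn.getresponse()
        raw = resp.read()  # drain the body so the connection can be reused
    except (OSError, http.client.HTTPException):
        # Discard a broken connection (e.g. RemoteDisconnected, read timeout)
        # so this thread reconnects on its next call instead of staying stuck.
        conn.close()
        _local.conn = None
        raise
    if resp.status != 200:
        raise RuntimeError(f"HTTP {resp.status}: {raw[:200]!r}")
    return json.loads(raw)["candidates"][0]["content"]["parts"][0]["text"]


def generate_all(prompts: list[str], api_key: str, max_workers: int = 5) -> list[str]:
    """Fan prompts out over a thread pool; results come back in input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(lambda p: generate(p, api_key), prompts))
```

With generate_all(prompts, api_key), each worker thread lazily opens its own connection, and a connection that errors out is thrown away rather than reused, so one bad response cannot wedge that thread's later calls.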

adenalhardan commented 6 months ago

Awesome, thank you!

RukshanJS commented 5 months ago


Could you kindly share some general code showing how you did that?

Oh wait, it's in the docs. Figured it out, thanks! This works: https://ai.google.dev/api/rest/v1beta/media/upload?hl=en

UPDATE: Halfway through the upload, I get this error though: Error processing batch: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))

adenalhardan commented 5 months ago

@RukshanJS Here's what I've been using, maybe it'll fix your error

import base64
import time

import requests
import google.auth.transport.requests
from google.oauth2 import service_account

class GoogleHTTPClient:
  def __init__(self, model: str = 'gemini-1.5-pro-preview-0409', max_retries: int = 5):
    self.model = model
    self.max_retries = max_retries
    self.project = 'YOUR PROJECT ID'
    self.region = 'YOUR REGION'

    # Load the service-account key once; access tokens are refreshed on demand.
    self.credentials = service_account.Credentials.from_service_account_file(
        './google_credentials.json',
        scopes=['https://www.googleapis.com/auth/cloud-platform'],
    )

  def request_message(self, messages: list[dict]) -> str:
    url = (
      f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project}"
      f"/locations/{self.region}/publishers/google/models/{self.model}:generateContent"
    )

    # `contents` is a list of turns; `messages` is already the list of parts for one user turn.
    data = {'contents': [{'role': 'user', 'parts': messages}]}

    last_error = None
    for attempt in range(self.max_retries + 1):
      headers = {
        'Authorization': f'Bearer {self._get_access_token()}',
        'Content-Type': 'application/json',
      }
      try:
        # An explicit timeout makes a stalled request fail instead of hanging.
        response = requests.post(url, headers=headers, json=data, timeout=60)
        response.raise_for_status()
        return response.json()['candidates'][0]['content']['parts'][0]['text'].strip()
      except (requests.RequestException, ValueError, KeyError, IndexError) as error:
        last_error = error
        if attempt < self.max_retries:
          time.sleep(min(2 ** attempt, 30))  # exponential backoff between attempts

    raise RuntimeError(f'Failed to fetch from Google after {self.max_retries + 1} attempts') from last_error

  def format_image_message(self, image: bytes) -> dict:
    return {
      "inlineData": {
        "mimeType": 'image/png',
        "data": base64.b64encode(image).decode('utf-8')
      }
    }

  def format_text_message(self, text: str) -> dict:
    return {'text': text}

  def _get_access_token(self) -> str:
    # Access tokens expire (typically after an hour), so refresh whenever the cached one is invalid.
    if not self.credentials.valid:
      self.credentials.refresh(google.auth.transport.requests.Request())
    return self.credentials.token
RukshanJS commented 5 months ago

This is very helpful, thanks a lot! Do you mind sharing how you use this client as well? I'm having a little trouble understanding why we have to use an access token here. Can't we call the Gemini API directly with an API key, without using the genai client?

RukshanJS commented 5 months ago

My current code is:


import concurrent.futures
import logging
import pathlib
from concurrent.futures import ThreadPoolExecutor, as_completed

import google.generativeai as genai
from tqdm import tqdm

logger = logging.getLogger(__name__)
# genai.configure(api_key=...) is assumed to run at import time (not shown)

def upload_file(file_path):
    return genai.upload_file(pathlib.Path(file_path))

def analyze_frames(frame_paths):
    model = genai.GenerativeModel("gemini-1.5-pro")
    prompt = """
Give me a description after looking at these images
"""

    # Upload the images using the File API
    logger.info("Uploading files...")
    file_references = [None] * len(frame_paths)

    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = {
            executor.submit(upload_file, frame_path): i
            for i, frame_path in enumerate(frame_paths)
        }
        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc="Uploading files",
        ):
            # as_completed yields in completion order; store by index so the
            # frames reach the model in their original order
            file_references[futures[future]] = future.result()

    logger.info("Uploading files completed...")

    # Generate content using the file references
    logger.info("Making inferences using the model %s...", model.model_name)
    response = model.generate_content(
        [
            prompt,
            *file_references,
        ]
    )
    return response.text

def process_batches(frames, batch_size):
    # Split the frames into batches
    batch_list = [frames[i : i + batch_size] for i in range(0, len(frames), batch_size)]
    batch_results = []

    with ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(analyze_frames, batch): batch for batch in batch_list
        }

        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc="Processing batches",
        ):
            batch_results.append(future.result())

    return batch_results

I call it like this: result = process_batches(frame_paths, 3550)

Problem

It uploads in parallel correctly for some time and then hangs (it makes no progress after this):

Uploading files:   4%|████▍                             155/3550 [00:53<19:29,  2.90it/s]
Uploading files:   9%|████████▊              158/1850 [00:55<09:53,  2.85it/s]

Observations

  1. It always gets stuck at around 160 images in each batch.
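This doesn't address whatever causes the stall, but a stuck call can at least be turned into an error instead of a silent hang. A generic sketch using only the standard library; run_with_deadline and timeout_s are hypothetical names, and fn stands for whatever per-item call (an upload, a request) is being fanned out:

```python
import concurrent.futures
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def run_with_deadline(
    fn: Callable[[T], R],
    items: Sequence[T],
    max_workers: int = 8,
    timeout_s: float = 300.0,
) -> List[R]:
    """Apply fn to every item on a bounded thread pool, preserving input order.

    Raises concurrent.futures.TimeoutError if the whole batch has not finished
    within timeout_s seconds, instead of blocking forever on a stuck call.
    """
    results: List = [None] * len(items)
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
    try:
        futures = {pool.submit(fn, item): i for i, item in enumerate(items)}
        # as_completed raises TimeoutError if futures are still pending after timeout_s
        for fut in concurrent.futures.as_completed(futures, timeout=timeout_s):
            results[futures[fut]] = fut.result()  # re-raises any exception from fn
    finally:
        # Don't wait on stuck workers; cancel queued work that never started (Python 3.9+).
        pool.shutdown(wait=False, cancel_futures=True)
    return results
```

Python cannot kill a thread, so a genuinely stuck call keeps running in the background after the timeout; the benefit is that the caller finds out and can log, retry, or abort instead of watching a frozen progress bar.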
MarkDaoust commented 1 month ago

Hey everybody,

I think @jpdaigle's original issue is resolved.

Aside from 429 Quota exceeded errors, I haven't gotten any errors. I added request_options=dict(retry=retry.Retry(timeout=600)) to allow retries with a long timeout.

The retries may help with a lot of other errors, which seem much less common now.
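The idea behind retry.Retry(timeout=600) is to keep retrying transient failures with growing pauses until an overall deadline passes. For code that calls the HTTP endpoint directly (as in the workarounds above), the same technique can be sketched with the standard library alone. This is an illustration, not google.api_core's implementation; the function name, the default values, and which exceptions count as transient are assumptions:

```python
import random
import time
from typing import Callable, Tuple, Type, TypeVar

R = TypeVar("R")


def call_with_retry(
    fn: Callable[[], R],
    # Assumed transient set: builtin errors raised by http.client/socket, e.g.
    # RemoteDisconnected (a ConnectionResetError) and read timeouts (TimeoutError on 3.10+).
    transient: Tuple[Type[BaseException], ...] = (ConnectionError, TimeoutError),
    initial: float = 1.0,
    maximum: float = 60.0,
    multiplier: float = 2.0,
    deadline: float = 600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> R:
    """Call fn, retrying transient errors with jittered exponential backoff.

    Gives up (re-raising the last error) once `deadline` seconds have elapsed;
    non-transient errors propagate immediately.
    """
    start = time.monotonic()
    delay = initial
    while True:
        try:
            return fn()
        except transient:
            elapsed = time.monotonic() - start
            if elapsed >= deadline:
                raise
            # "Full jitter": sleep a random fraction of the current delay,
            # but never past the overall deadline.
            sleep(min(random.uniform(0.0, delay), deadline - elapsed))
            delay = min(delay * multiplier, maximum)
```

A requests-based client would pass requests.exceptions.ConnectionError and requests.exceptions.Timeout as the transient tuple instead, since those do not inherit from the builtin ConnectionError.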

https://gist.github.com/MarkDaoust/dcd65b626bf4683860aa510b79bc225e

So I think this bug is fixed.


@RukshanJS I've heard other reports of failures with threading specifically for file uploads.

How about we continue this in https://github.com/google-gemini/generative-ai-python/issues/327