googleapis / python-aiplatform

A Python SDK for Vertex AI, a fully managed, end-to-end platform for data science and machine learning.
Apache License 2.0

[LLM] When using Celery Python `from_pretrained` call hangs since SDK version 1.33.1 #2620

Closed LEAGUEDORA closed 10 months ago

LEAGUEDORA commented 11 months ago

Steps to reproduce

  1. Write code in FastAPI and add Celery to it
  2. Use the Vertex Chat/Text Model in a Celery task

Let me explain what we are doing. We have FastAPI 0.100.1 as our backend, and alongside it we run Celery on the exact same code base. When we upgraded google-cloud-aiplatform to 1.33.1, FastAPI loaded fine, but Celery did not. We debugged by adding a print statement after every line in our code.

At exactly this point

chat_model = ChatModel.from_pretrained("chat-bison@001")

the code pauses. It neither throws an error nor kills itself. But the exact same code, with the same versions, works fine under FastAPI.

This is very strange. It works fine when we downgrade the package to 1.32.0.

LEAGUEDORA commented 11 months ago

Here is my requirements file:

pymongo<=4.5.0
fastapi<=0.100.1
python-jose[cryptography]<=3.3.0
passlib[bcrypt]<=1.7.4
python-multipart<=0.0.6
uvicorn<=0.23.2
requests<=2.31.0
regex<=2023.6.3
pydantic==1.10.9
redis[hiredis]
num2words<=0.5.12
nest_asyncio<=1.5.7
pytz<=2023.3
python-dateutil<=2.8.2
librosa<=0.10.0.post2
soundfile<=0.12.1
openai<=0.27.8
celery<=5.3.1
deepgram-sdk==2.8.0
nvidia-riva-client
websockets<=11.0.3
google-cloud-aiplatform==1.32.0
httpx==0.24.1
pytest==7.4.0
pyyaml<=6.0.1
boto3<=1.24.37
twilio<=6.50.1
aiofiles<=23.2.1
aiologger<=0.7.0
gevent<=23.7.0
Ark-kun commented 11 months ago

Thank you for the bug report. This looks really puzzling. Is it possible to get a stack trace while it's hanging? For example, you can use a debugger or interrupt the thread to get the stack trace.

If you cannot get a stack trace, can you check whether you can create an instance of aiplatform.Endpoint(name=...) and see whether that hangs?

LEAGUEDORA commented 11 months ago

Thanks @Ark-kun for checking this. But in the VS Code debugger, the threads are not even showing up.

Can you tell me what the endpoint name should be for the chat-bison@001 model?

Ark-kun commented 11 months ago

But in the VS Code debugger, the threads are not even showing up.

I'm not sure exactly how you run the code. Can you give a minimal PoC?

When running code interactively in a console or in a Jupyter notebook, you can stop the execution (e.g. Ctrl+C) and Python will raise KeyboardInterrupt and print the stack trace.

Can you tell me what the endpoint name should be for the chat-bison@001 model?

It does not have a corresponding Vertex Endpoint. Do you have any Vertex Endpoints that you can test on? Just the Endpoint object construction.

Ark-kun commented 11 months ago

Actually, try this:

aiplatform.Endpoint._construct_sdk_resource_from_gapic(
    aiplatform.models.gca_endpoint_compat.Endpoint(name="projects/<YOUR_PROJECT_ID>/locations/us-central1/publishers/google/models/text-bison@001"),
)
LEAGUEDORA commented 11 months ago

Even Ctrl+C is not working. I tried five times, and long-pressed too.

Actually, try this:

aiplatform.Endpoint._construct_sdk_resource_from_gapic(
    aiplatform.models.gca_endpoint_compat.Endpoint(name="projects/<YOUR_PROJECT_ID>/locations/us-central1/publishers/google/models/text-bison@001"),
)

I tried this. It also hangs there.
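As a side note: when Ctrl+C is swallowed like this, the stdlib `faulthandler` module can still dump every thread's stack, which is often the fastest way to see where a worker is stuck. A sketch, assuming you can add a few lines to the worker's startup code:

```python
import faulthandler
import signal
import sys

# Dump the stack of every thread in this process to stderr right now.
faulthandler.dump_traceback(file=sys.stderr)

# Or register a signal so a hung worker can be inspected from outside
# with `kill -USR1 <pid>` (SIGUSR1 is not available on Windows).
if hasattr(signal, "SIGUSR1"):
    faulthandler.register(signal.SIGUSR1)
```

faulthandler writes directly to the file descriptor from a C-level handler, so it works even when Python-level signal handling (and hence Ctrl+C) never gets a chance to run.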

Ark-kun commented 11 months ago

Thank you for your help investigating. Let's pin it down further: It looks like the culprit is somewhere here: https://github.com/googleapis/python-aiplatform/commit/e9eb159756dfe90c9f72818204fa74d05096aec6#diff-72d6e2818d2224db618771b26094747754d49a2acf33160eaf017ff1b6ddf840R1415-R1429

Let's try these two snippets.

First, the event-loop setup:

import asyncio
try:
    asyncio.get_event_loop()
except RuntimeError as ex:
    print(ex)
    asyncio.set_event_loop(asyncio.new_event_loop())

Second, the async client creation:

from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils

async_client = initializer.global_config.create_client(
    client_class=utils.PredictionAsyncClientWithOverride,
    prediction_client=True,
)

Which one of them causes the issue?

LEAGUEDORA commented 11 months ago

Should I edit this in my local code?

jameschristopher commented 11 months ago

I don't think the actual issue is with Celery; I think it is with gevent. I'm guessing you are running Celery with the gevent worker class. I am running into the same issue when calling the classic .predict method under a gunicorn worker with the gevent class. My guess is that the issue is something in the newly defined async client code.

This is a classic failure mode with gevent incompatible code.

LEAGUEDORA commented 11 months ago

Hey @jameschristopher ,

I just checked with eventlet. The same issue happens.

jameschristopher commented 11 months ago

Yeah, that checks out, since both eventlet and gevent are greenlet-based async libraries. In general you don't want to mix two different async methods, so gevent + asyncio or eventlet + asyncio would cause issues. The issue we're dealing with was likely introduced in this commit: https://github.com/googleapis/python-aiplatform/commit/e9eb159756dfe90c9f72818204fa74d05096aec6 . These kinds of issues are hard to debug.

Ark-kun commented 11 months ago

The issue we're dealing with was likely introduced in this commit https://github.com/googleapis/python-aiplatform/commit/e9eb159756dfe90c9f72818204fa74d05096aec6 .

Yes. We're debugging it. That commit has added the following two pieces of code:

import asyncio
try:
    asyncio.get_event_loop()
except RuntimeError as ex:
    print(ex)
    asyncio.set_event_loop(asyncio.new_event_loop())

and:

from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils

async_client = initializer.global_config.create_client(
    client_class=utils.PredictionAsyncClientWithOverride,
    prediction_client=True,
)

@LEAGUEDORA

Should I edit this in my local code?

Yes. The same way you tested the aiplatform.Endpoint hanging.

Ark-kun commented 11 months ago

@LEAGUEDORA Were you able to test the asyncio snippet with Celery?

jameschristopher commented 11 months ago

I did some testing and it is this block of code that causes the issue.

try:
    asyncio.get_event_loop()
except RuntimeError as ex:
    print(ex)
    asyncio.set_event_loop(asyncio.new_event_loop())

When running under a greenlet-based threading regime like gevent or eventlet, calling basically anything in asyncio will cause an error, because the system is trying to run two different kinds of async regimes at the same time.
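For what it's worth, under stock CPython (no gevent or eventlet monkey-patching) that block is benign; a minimal sketch showing it just creates an ordinary loop:

```python
import asyncio

# The same pattern the SDK uses: fall back to a fresh event loop if the
# current thread has none. Under plain CPython this completes instantly.
try:
    loop = asyncio.get_event_loop()
except RuntimeError as ex:
    print(ex)
    asyncio.set_event_loop(asyncio.new_event_loop())
    loop = asyncio.get_event_loop()

print(isinstance(loop, asyncio.AbstractEventLoop))  # True
```

The hang only appears once a greenlet hub has monkey-patched the primitives asyncio builds on, which is presumably why the same code works under a plain FastAPI/uvicorn process.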

Looking at the code, it looks like the async prediction client is instantiated and returned whether it is used or not. From my understanding, as long as those few asyncio lines are run, the library will be incompatible with anyone using gevent or eventlet.

Perhaps there is a way to lazily instantiate the prediction client in the predict_async and explain_async methods? Or at least check for an existing event loop in those methods?
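That lazy-instantiation idea could look roughly like the sketch below. Everything here is illustrative: `_FakeAsyncClient` and `make_async_prediction_client` stand in for the SDK's real client and for `initializer.global_config.create_client`, and `LazyEndpoint` is not the SDK's class.

```python
class _FakeAsyncClient:
    """Stand-in for the SDK's async prediction client (illustrative only)."""
    async def predict(self, instances):
        return {"predictions": instances}


def make_async_prediction_client():
    # Hypothetical factory; the real SDK would call
    # initializer.global_config.create_client(...) here.
    return _FakeAsyncClient()


class LazyEndpoint:
    def __init__(self):
        # No asyncio work at construction time, so synchronous callers
        # (including gevent/eventlet workers) are unaffected.
        self._async_client = None

    @property
    def async_client(self):
        # Create the async client only on first access.
        if self._async_client is None:
            self._async_client = make_async_prediction_client()
        return self._async_client

    async def predict_async(self, instances):
        return await self.async_client.predict(instances)
```

With this shape, purely synchronous callers never execute any asyncio code, so greenlet-based workers would only break if they actually call predict_async.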

Ark-kun commented 10 months ago

When running under a greenlet-based threading regime like gevent or eventlet, calling basically anything in asyncio will cause an error, because the system is trying to run two different kinds of async regimes at the same time.

It is concerning that these systems conflict with a major built-in Python feature. Could these systems be fixed to play nicer with Python async, or plug into it? The asyncio event loop seems to be extensible and pluggable. It might be possible to register those systems in a way that new_event_loop creates something native to them. Are there any existing solutions for this?
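On pluggability: asyncio's event-loop policy is indeed a registration point, and third-party bridges such as aiogevent take roughly this approach. A sketch of the hook (the policy class here is illustrative, not a working gevent integration):

```python
import asyncio


class GreenletFriendlyPolicy(asyncio.DefaultEventLoopPolicy):
    """Illustrative policy: a greenlet framework could return an event
    loop backed by its own hub instead of the default selector loop."""

    def new_event_loop(self):
        # A real integration would build a gevent/eventlet-aware loop
        # here; this sketch just delegates to the default implementation.
        return super().new_event_loop()


# After registration, every asyncio.new_event_loop() call (including the
# one inside the SDK's fallback) goes through the custom policy.
asyncio.set_event_loop_policy(GreenletFriendlyPolicy())
loop = asyncio.new_event_loop()
loop.close()
```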

I do not think it's scalable to have per-system methods like:

def predict_async_asyncio(...)
def predict_async_gevent(...)
def predict_async_eventlet(...)