replicate / replicate-python

Python client for Replicate
https://replicate.com
Apache License 2.0

Setting timeouts for replicate.run() #272

Closed: darhsu closed this issue 2 weeks ago

darhsu commented 6 months ago

Hello, I was wondering how to set a timeout for the replicate.run() function.

I tried passing a timeout to the Replicate client, but it didn't raise a timeout error:

from replicate.client import Client

replicate_client = Client(api_token="my_api_token", timeout=1)
replicate_client.run(
    "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={"prompt": "a photo of an astronaut riding a horse on Mars"},
)

When I printed the timeout, it correctly displayed the value I had set:

print(replicate_client._timeout)
# Output: 1

mattt commented 6 months ago

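Hi @darhsu. The timeout parameter is passed to the underlying httpx client instance, so it applies to each HTTP request individually: the initial request that creates the prediction and each subsequent request that polls for its completion. Each of those requests is likely to take less than a second, so a 1-second timeout is unlikely to fire, even though the prediction as a whole takes longer.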
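To see why a per-request timeout doesn't bound the total run time, here is a minimal, self-contained simulation using only asyncio (no Replicate or httpx calls). The names fake_request and create_and_poll are hypothetical stand-ins for the create and poll requests, and asyncio.wait_for plays the role of the HTTP client's per-request timeout.

```python
import asyncio
import time


async def fake_request(latency: float) -> str:
    """Hypothetical stand-in for one HTTP request (create or poll); sleeps instead of using the network."""
    await asyncio.sleep(latency)
    return "ok"


async def create_and_poll(num_polls: int, per_request_timeout: float, latency: float) -> float:
    """Issue one simulated 'create' request plus `num_polls` simulated 'poll' requests.

    The timeout wraps each request separately, mirroring how an HTTP client
    timeout applies per request rather than to the whole prediction.
    Returns the total elapsed time in seconds.
    """
    start = time.monotonic()
    for _ in range(1 + num_polls):
        await asyncio.wait_for(fake_request(latency), timeout=per_request_timeout)
    return time.monotonic() - start


elapsed = asyncio.run(create_and_poll(num_polls=5, per_request_timeout=0.1, latency=0.05))
print(f"Total time {elapsed:.2f}s exceeded the 0.1s per-request timeout without raising")
```

Each simulated request finishes inside its 0.1-second limit, so no timeout error is raised even though the six requests together take roughly 0.3 seconds.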

In general, I'd recommend against relying on the HTTP timeout for this. Instead, set your own deadline and call cancel on any predictions that haven't completed by then.
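A minimal synchronous sketch of that approach is below. It assumes the prediction object has a status attribute plus reload() and cancel() methods, as replicate's Prediction object does; wait_with_deadline, TERMINAL_STATUSES, and the default polling interval are hypothetical choices for illustration, not part of the library.

```python
import time

# Statuses after which a prediction will not change any more
TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


def wait_with_deadline(prediction, deadline_seconds: float, poll_interval: float = 1.0) -> bool:
    """Hypothetical helper: poll `prediction` until it finishes or the deadline passes.

    Assumes `prediction` exposes `status`, `reload()`, and `cancel()`.
    Returns True if it reached a terminal status in time, False if it was canceled.
    """
    deadline = time.monotonic() + deadline_seconds
    while prediction.status not in TERMINAL_STATUSES:
        if time.monotonic() >= deadline:
            prediction.cancel()  # ask the server to stop the prediction
            return False
        time.sleep(poll_interval)
        prediction.reload()  # refresh the status from the API
    return True


# Usage sketch (needs REPLICATE_API_TOKEN set and network access):
# import replicate
# prediction = replicate.predictions.create(
#     version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
#     input={"prompt": "a photo of an astronaut riding a horse on Mars"},
# )
# if wait_with_deadline(prediction, deadline_seconds=60):
#     print(prediction.output)
```

Polling with a monotonic clock keeps the deadline independent of wall-clock adjustments; the poll interval trades API traffic against how quickly a late prediction gets canceled.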

Also note that the model you're running, SDXL, typically takes a few seconds to run, so a one-second deadline would almost always fail to produce results.

darhsu commented 5 months ago

@mattt Thanks for your help. Do you have any pointers on canceling async runs?

This is what I have so far. I've added a 60-second timeout to the sample code here:

import asyncio
import replicate

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]

async def main():
    results = None
    try:
        # asyncio.timeout requires Python 3.11+. It wraps the TaskGroup because
        # the TaskGroup already waits for every task before its block exits.
        async with asyncio.timeout(60):
            async with asyncio.TaskGroup() as tg:
                tasks = [
                    tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
                    for prompt in prompts
                ]
            results = [task.result() for task in tasks]
    except TimeoutError:
        # Cancel replicate async run: still to do.
        # Timing out here only stops the local asyncio tasks.
        pass

    print(results)

asyncio.run(main())
mattt commented 2 weeks ago

@darhsu Sorry for the late response.

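To cancel a prediction after a given timeout, you'll need to break replicate.async_run up into two steps: create the prediction with replicate.predictions.async_create, then wait for it with prediction.async_wait(). That way you hold a reference to each prediction and can cancel it if the deadline passes: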

import asyncio
import replicate

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]

async def create_prediction(prompt):
    # Start a prediction without waiting for it to finish.
    # Pass the version ID, i.e. the hash after the colon.
    return await replicate.predictions.async_create(
        version=model_version.split(":")[1],
        input={"prompt": prompt},
    )

async def main():
    # Defined up front so the except block still works if creation times out
    predictions = []
    try:
        # asyncio.timeout requires Python 3.11+
        async with asyncio.timeout(60):
            predictions = await asyncio.gather(*[create_prediction(prompt) for prompt in prompts])

            # Wait for all predictions to complete
            await asyncio.gather(*[prediction.async_wait() for prediction in predictions])

            results = [prediction.output for prediction in predictions]
            print(results)
    except TimeoutError:
        print("Timeout occurred. Canceling predictions...")
        # Cancel every prediction that hasn't reached a terminal state
        await asyncio.gather(*[
            prediction.async_cancel()
            for prediction in predictions
            if prediction.status not in ["succeeded", "failed", "canceled"]
        ])
        print("Predictions canceled.")

asyncio.run(main())
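The example above applies one deadline to the whole batch. If you'd rather give each prediction its own deadline, a sketch like the following wraps a single prediction's async_wait() in asyncio.wait_for and cancels the prediction if it runs over. wait_or_cancel is a hypothetical helper; it assumes the object exposes status, output, async_wait(), and async_cancel(), as replicate's Prediction object does.

```python
import asyncio

# Statuses after which a prediction will not change any more
TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


async def wait_or_cancel(prediction, timeout: float):
    """Hypothetical helper: wait for one prediction, canceling it if it misses the deadline.

    Returns the prediction's output, or None if it was canceled for running too long.
    """
    try:
        await asyncio.wait_for(prediction.async_wait(), timeout=timeout)
        return prediction.output
    except asyncio.TimeoutError:
        # wait_for has already cancelled the local wait; now stop the remote prediction
        if prediction.status not in TERMINAL_STATUSES:
            await prediction.async_cancel()
        return None
```

Inside an async function, with a list of predictions created as above, you could then run asyncio.gather(*[wait_or_cancel(p, 60) for p in predictions]); each slow prediction is canceled independently and comes back as None, while the others still return their output.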