darhsu closed this issue 2 weeks ago.
Hi @darhsu. The `timeout` parameter is passed to the underlying httpx client instance. It sets the timeout for each individual HTTP request: the initial request that creates a prediction, and each subsequent request that polls for its completion. Each of these requests is likely to take less than a second.
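Here is a minimal offline sketch of that distinction. It uses plain asyncio rather than httpx or the Replicate client, and the hypothetical `fake_poll` coroutine stands in for one status request. Each request is bounded by its own timeout (via `asyncio.wait_for`), yet the whole polling loop runs several times longer than that timeout without ever raising.

```python
import asyncio
import time

PER_REQUEST_TIMEOUT = 0.05  # seconds allowed for each individual "request"


async def fake_poll(attempt: int) -> str:
    """Hypothetical stand-in for one status request; it returns quickly."""
    await asyncio.sleep(0.01)
    return "succeeded" if attempt >= 9 else "processing"


async def wait_for_completion() -> tuple[str, int, float]:
    """Poll until the fake prediction succeeds, timing out each request individually."""
    start = time.monotonic()
    attempt = 0
    while True:
        # The timeout applies to this single request, not to the loop as a whole
        status = await asyncio.wait_for(fake_poll(attempt), timeout=PER_REQUEST_TIMEOUT)
        attempt += 1
        if status == "succeeded":
            return status, attempt, time.monotonic() - start
        await asyncio.sleep(0.02)  # poll interval between requests


status, requests_made, elapsed = asyncio.run(wait_for_completion())
print(status, requests_made, f"{elapsed:.2f}s")
```

A per-request timeout guards against a single hung connection. Bounding the total run time needs a separate, overall deadline.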
In general, I'd recommend against this kind of approach. Instead, call `cancel` on any predictions that haven't completed by some deadline.
Also note that the model you're running, SDXL, typically takes a few seconds to run, so a one-second timeout would almost always fail to produce results.
@mattt Thanks for your help. Do you have any pointers on canceling async runs?
This is what I have so far; I've added a 60-second timeout to the sample code here:
```python
import asyncio

import replicate

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]


async def main():
    results = None
    async with asyncio.TaskGroup() as tg:
        tasks = [
            tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
            for prompt in prompts
        ]
        try:
            async with asyncio.timeout(60):
                results = await asyncio.gather(*tasks)
        except TimeoutError:
            # Cancel replicate async run
            pass
    print(results)


asyncio.run(main())
```
@darhsu Sorry for the late response.
To cancel a prediction after a given timeout, you'll need to break up `replicate.async_run` into two steps, `replicate.predictions.async_create` and `prediction.async_wait`, so that you hold a reference to each prediction and can cancel it:
```python
import asyncio

import replicate

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
# predictions.async_create takes the version ID: the part after the colon
version_id = model_version.split(":", 1)[1]
prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]


async def run_prediction(prompt):
    # Start the prediction and return it immediately, without waiting for output
    return await replicate.predictions.async_create(
        version=version_id,
        input={"prompt": prompt},
    )


async def main():
    # Create every prediction first, so each one can be canceled later
    predictions = await asyncio.gather(*[run_prediction(prompt) for prompt in prompts])
    try:
        async with asyncio.timeout(60):
            # Wait for all predictions to complete
            await asyncio.gather(*[prediction.async_wait() for prediction in predictions])
        results = [prediction.output for prediction in predictions]
        print(results)
    except TimeoutError:
        print("Timeout occurred. Canceling predictions...")
        # Cancel all predictions that haven't reached a terminal state
        await asyncio.gather(
            *[
                prediction.async_cancel()
                for prediction in predictions
                if prediction.status not in ["succeeded", "failed", "canceled"]
            ]
        )
        print("Predictions canceled.")


asyncio.run(main())
```
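A variation worth considering is `asyncio.wait(..., timeout=...)`. It is sketched here offline, with a hypothetical `fake_prediction` coroutine in place of the Replicate client. Instead of raising on timeout, it returns the finished and unfinished sets, so results from predictions that completed in time are kept and only the stragglers need canceling. With Replicate, canceling the local asyncio task presumably only stops the polling. The remote prediction would still need an explicit `prediction.async_cancel()`.

```python
import asyncio


async def fake_prediction(prompt: str, duration: float) -> str:
    """Hypothetical stand-in for waiting on one prediction's output."""
    await asyncio.sleep(duration)
    return f"image for {prompt!r}"


async def main() -> tuple[list[str], int]:
    tasks = [
        asyncio.create_task(fake_prediction("two unicorns", 0.01)),
        asyncio.create_task(fake_prediction("four unicorns", 0.02)),
        asyncio.create_task(fake_prediction("six unicorns", 5.0)),
    ]
    # wait() does not raise on timeout; it reports which tasks finished
    done, pending = await asyncio.wait(tasks, timeout=0.5)
    for task in pending:
        # With Replicate you would also cancel the remote prediction here
        task.cancel()
    await asyncio.gather(*pending, return_exceptions=True)  # let cancellations settle
    # Keep results in the original order, skipping anything that timed out
    results = [t.result() for t in tasks if t in done]
    return results, len(pending)


results, timed_out = asyncio.run(main())
print(results)
print(f"{timed_out} prediction(s) canceled")
```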
Hello, I was wondering how you can set timeouts in the `replicate.run()` function.
I have tried using the Replicate client, but it didn't throw a timeout error:
When I printed the timeout, it displayed the correct value.