replicate / replicate-python

Python client for Replicate
https://replicate.com
Apache License 2.0

Ability to process in Batches #299

Open charliemday opened 2 months ago

charliemday commented 2 months ago

Is there any ability to batch process, similar to how OpenAI does it with its Batch API (https://help.openai.com/en/articles/9197833-batch-api-faq)?

mattt commented 2 months ago

Hi @charliemday. No, Replicate doesn't currently implement a batch processing API like OpenAI. It's something we're considering, though. Can you share more about your intended use case?

charliemday commented 2 months ago

Hi @mattt, sure.

My intended use case is that I have a CSV file of ~1k rows and I want to send a request for each row. Given that each response takes ~1-2 seconds, this is going to take ~20 minutes, give or take. I would like to batch the rows into groups of 10 and send each group all at once, bringing the time to completion down considerably.
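Back-of-the-envelope, assuming ~1.5 s per request as the midpoint of that range (the numbers below are just that assumption worked through):

rows = 1000
seconds_per_request = 1.5  # assumed midpoint of the ~1-2 s range

sequential_minutes = rows * seconds_per_request / 60        # ~25 minutes, one at a time
batched_minutes = (rows / 10) * seconds_per_request / 60    # ~2.5 minutes in groups of 10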

mattt commented 2 months ago

@charliemday Replicate does support creating up to 6000 concurrent predictions per minute. Depending on how much the model is scaled out, you could process all of them more quickly using our async API. Here's an example from the README that you can adapt (I'd recommend processing rows in async batches of 100 or so, and keeping track of successful and failed rows):

import asyncio
import replicate

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]

async def main():
    # asyncio.TaskGroup requires Python 3.11+; it waits for all tasks
    # to finish before the `async with` block exits.
    async with asyncio.TaskGroup() as tg:
        tasks = [
            tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
            for prompt in prompts
        ]
    # Every task is complete at this point, so read the results directly.
    results = [task.result() for task in tasks]
    print(results)

asyncio.run(main())
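
For the bookkeeping part, here's a rough sketch of one way to process rows in fixed-size batches and keep successes and failures separate (the batch size, function name, and error handling are illustrative, not part of the client API):

import asyncio
import replicate

model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"

async def process_all(prompts, batch_size=100):
    succeeded, failed = [], []
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start : start + batch_size]
        # return_exceptions=True keeps one failed prediction from
        # cancelling the rest of the batch.
        results = await asyncio.gather(
            *(replicate.async_run(model_version, input={"prompt": p}) for p in batch),
            return_exceptions=True,
        )
        for prompt, result in zip(batch, results):
            if isinstance(result, Exception):
                failed.append((prompt, result))
            else:
                succeeded.append((prompt, result))
    return succeeded, failed

Calling it is then just succeeded, failed = asyncio.run(process_all(prompts)), and the failed list can be retried or logged afterwards.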
charliemday commented 2 months ago

Thanks @mattt, this should work for my use case.

Maybe I'm missing something, but I don't see 6000 RPM at the link (only 600)?

mattt commented 2 months ago

@charliemday Apologies, yes: the rate limit is 600 per minute. That was a typo on my part.
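
If you want to stay safely under that limit, one rough approach (a sketch, not a client feature; the run_paced helper below is hypothetical) is to space out task creation so no more than 600 predictions start per minute:

import asyncio
import replicate

async def run_paced(model_version, inputs, per_minute=600):
    # 600 per minute == one new prediction every 0.1 s.
    interval = 60.0 / per_minute
    tasks = []
    for inp in inputs:
        tasks.append(asyncio.create_task(replicate.async_run(model_version, input=inp)))
        await asyncio.sleep(interval)
    # Collect results without letting one failure cancel the rest.
    return await asyncio.gather(*tasks, return_exceptions=True)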