replicate / replicate-python

Python client for Replicate
https://replicate.com
Apache License 2.0

Failed image generation returns empty list #397

Open · cellwebb opened this issue 1 week ago

cellwebb commented 1 week ago

Backstory: I created a Discord bot for my friend group's server that, among other things, can take an image and a prompt from a specific channel, generate a new image using sd-3.5-large-turbo, and post it back to the same channel. Today, one of my friends shared the image below with the prompt "leg day", and instead of a new image, Replicate returned an empty list. Other images and prompts worked, so I checked my Replicate dashboard, which says the prediction failed because NSFW content was detected in the generated image. Weird, but okay.

[image: the photo shared with the "leg day" prompt]

Request: Rather than returning just an empty list, could an exception be raised, or something returned, that indicates why the prediction failed?

Relevant code:

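import base64
import os
from datetime import datetime
from io import BytesIO

import discord
import replicate
from PIL import Image

# This runs inside an async Discord handler where `message`, `attachment`
# (a discord.Attachment) and `prompt` are already defined.
params = {
    "model": "stability-ai/stable-diffusion-3.5-large-turbo",
    "cfg": 1,
    "steps": 4,
    "prompt_strength": 0.8,
}

# Save the original attachment, keeping its file extension.
filetype = attachment.filename.split(".")[-1]
os.makedirs("images/variations/originals", exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
original_filepath = f"images/variations/originals/{timestamp}.{filetype}"
await attachment.save(original_filepath)

# Re-encode the image as PNG and pass it to the model as a base64 data URI.
with Image.open(original_filepath) as img:
    buff = BytesIO()
    img.save(buff, format="PNG")
    img_str = base64.b64encode(buff.getvalue()).decode("utf-8")
    image = f"data:image/png;base64,{img_str}"

# Run the image-to-image prediction.
response = await replicate.async_run(
    params["model"],
    input={
        "image": image,
        "prompt": prompt,
        "cfg": params["cfg"],
        "steps": params["steps"],
        "prompt_strength": params["prompt_strength"],
    },
)

# The output may be a list; keep the first item.
# When the prediction fails (e.g. NSFW detection), the list is empty,
# so response[0] raises IndexError here.
if isinstance(response, list):
    response = response[0]

# Save the generated image and post it back to the channel.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
variation_filepath = f"images/variations/{timestamp}.png"
with open(variation_filepath, "wb") as f:
    f.write(response.read())

await message.channel.send(":D", file=discord.File(variation_filepath))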
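The failure point in the code above is `response[0]` on an empty list. Until the library raises on its own, a caller can guard the output explicitly. The following is a minimal sketch: `first_output` and `EmptyOutputError` are hypothetical names, not part of replicate-python, and the sketch only assumes the output is either a list or a single file-like object, as in the bot code.

```python
class EmptyOutputError(RuntimeError):
    """Hypothetical error raised when a model run produces no output."""


def first_output(output):
    """Return the first item of a list output, or the output itself otherwise.

    Raises EmptyOutputError instead of letting ``output[0]`` fail with an
    IndexError when the model returned an empty list (or None).
    """
    if output is None:
        raise EmptyOutputError("model returned no output (None)")
    if isinstance(output, (list, tuple)):
        if not output:
            raise EmptyOutputError(
                "model returned an empty list; check the prediction's status "
                "and error on the Replicate dashboard"
            )
        return output[0]
    # Non-list outputs (e.g. a single file-like object) pass through unchanged.
    return output
```

In the bot, `response = first_output(response)` would replace the `isinstance` check, and the handler could catch `EmptyOutputError` to post a message to the channel instead of failing with an unhelpful IndexError.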
aron commented 1 week ago

Thanks for the report. The replicate.run() call should raise a ModelError if the prediction fails, whether because of an NSFW error or anything else. Let me take a look at the model you've provided and see if something is off.
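The behavior described here (raise when a prediction fails rather than hand back empty output) amounts to checking the prediction's final state. The sketch below illustrates that idea only; it is not replicate-python's implementation. `PredictionFailed` and `raise_for_failed_prediction` are hypothetical names, and the sketch assumes only that a finished prediction exposes `status`, `error`, and `output` attributes, with `"failed"` and `"canceled"` as the unsuccessful terminal states.

```python
from types import SimpleNamespace  # used only for the stand-in prediction below


class PredictionFailed(Exception):
    """Hypothetical exception carrying the failed prediction and its reason."""

    def __init__(self, prediction):
        self.prediction = prediction
        # Prefer the prediction's own error text; fall back to its status.
        reason = prediction.error or f"prediction {prediction.status}"
        super().__init__(reason)


def raise_for_failed_prediction(prediction):
    """Return the prediction's output, raising if it failed or was canceled."""
    if prediction.status in ("failed", "canceled"):
        raise PredictionFailed(prediction)
    return prediction.output


# Stand-in object; a real prediction's error text comes from Replicate.
failed = SimpleNamespace(status="failed", error="illustrative error message", output=None)
try:
    raise_for_failed_prediction(failed)
except PredictionFailed as exc:
    print(f"Prediction failed: {exc}")
```

With the library behaving as described, the equivalent in the bot is wrapping the `replicate.async_run(...)` call in `try`/`except` for `replicate.exceptions.ModelError` and reporting the reason back to the channel.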