replicate / cog-flux

Cog inference for flux models
https://replicate.com/black-forest-labs/flux-dev
Apache License 2.0

Flux-Dev img2img #6

Closed · lucataco closed 2 months ago

lucataco commented 2 months ago

Img2img for flux-dev.

Text-to-image:

cog predict -i prompt="a tiny astronaut hatching from an egg on the moon"

Image-to-image, reusing the first output as the init image:

cog predict -i prompt="a tiny astronaut hatching from a purple egg on the moon" -i image=@txt2img.png

zsxkib commented 2 months ago

@lucataco Hey, about the NSFW checking in flux-dev:

Right now, if any generated image is flagged as NSFW, we raise an exception and discard all of the images. Maybe we could:

  1. Check each image separately
  2. Keep the safe ones, skip the NSFW ones
  3. Return all safe images to the user
  4. Maybe tell the user how many images we filtered out

This way, users get some images even if a few are NSFW. What do you think?
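The per-image filtering idea above can be sketched independently of the model code. This is a minimal illustration with a stubbed checker (`filter_nsfw` and the lambda are hypothetical stand-ins, not functions from this repo), just to show the keep-safe/count-filtered pattern:

```python
def filter_nsfw(images, is_nsfw):
    """Keep only safe images; count how many were filtered out."""
    safe, nsfw_count = [], 0
    for i, img in enumerate(images):
        if is_nsfw(img):
            # Skip this image instead of raising and losing the whole batch.
            nsfw_count += 1
            print(f"NSFW content detected in image {i+1}; it will not be returned.")
        else:
            safe.append(img)
    return safe, nsfw_count

# Stub checker: treat images tagged "nsfw" as unsafe.
images = ["a", "nsfw", "b"]
safe, filtered = filter_nsfw(images, lambda img: img == "nsfw")
# safe == ["a", "b"], filtered == 1
```

In the real predictor, `is_nsfw` would be the existing `run_safety_checker` call, and the returned count could drive the user-facing message about how many images were filtered.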


Here's what you could try.

Current code:

output_paths = []
for i in range(num_outputs):
    # bring into PIL format and save
    img = rearrange(x[i], "c h w -> h w c").clamp(-1, 1)
    img = Image.fromarray((127.5 * (img + 1.0)).cpu().byte().numpy())

    if not disable_safety_checker:
        if self.offload:
            self.safety_checker = self.safety_checker.to("cuda")
        _, has_nsfw_content = self.run_safety_checker(img)
        if has_nsfw_content[0]:
            raise Exception(f"NSFW content detected in image {i+1}. Try running it again, or try a different prompt.")
...
return output_paths

Modified code:

output_paths = []
nsfw_count = 0
for i in range(num_outputs):
    # bring into PIL format and save
    img = rearrange(x[i], "c h w -> h w c").clamp(-1, 1)
    img = Image.fromarray((127.5 * (img + 1.0)).cpu().byte().numpy())

    is_nsfw = False
    if not disable_safety_checker:
        if self.offload:
            self.safety_checker = self.safety_checker.to("cuda")
        _, has_nsfw_content = self.run_safety_checker(img)
        if has_nsfw_content[0]:
            is_nsfw = True
            nsfw_count += 1
            print(f"NSFW content detected in image {i+1}. This image will not be returned.")

    if not is_nsfw:
        output_path = f"out-{i}.{output_format}"
        if output_format != 'png':
            img.save(output_path, quality=output_quality, optimize=True)
        else:
            img.save(output_path)
        output_paths.append(Path(output_path))

emit_metric("num_images", len(output_paths))
if nsfw_count > 0:
    print(f"Total {nsfw_count} image(s) were filtered out due to NSFW content.")
if len(output_paths) == 0:
    print("All generated images contained NSFW content. Try running it again with a different prompt.")
return output_paths

Let me know if you want to chat about this!