VSAnimator / Sketch-a-Sketch

Controlling diffusion-based image generation with just a few strokes
https://vsanimator.github.io/sketchasketch/
MIT License
55 stars, 1 fork

Command line interface #3

Status: open. dataf3l opened this issue 10 months ago.

dataf3l commented 10 months ago

For those who like command line interfaces, dislike Gradio and its timeouts, and have an M1 CPU (which runs this slowly), you can use this code example:

```python
import argparse

import numpy as np
import torch
from diffusers import ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetPipeline
from PIL import Image


def sketch(prompt, curr_sketch_path, output_path, negative_prompt="", num_steps=20, seed=None):
    # Set up device and models. CPU is used here; on Apple Silicon, torch.device("mps")
    # may be faster when torch.backends.mps.is_available() returns True.
    device = torch.device("cpu")
    controlnet = ControlNetModel.from_pretrained("vsanimator/sketch-a-sketch", torch_dtype=torch.float32).to(device)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        # If this repo id is no longer available on the Hugging Face Hub, the mirror
        # "stable-diffusion-v1-5/stable-diffusion-v1-5" may work in its place.
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=torch.float32,
    ).to(device)
    pipe.safety_checker = None
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    # Pick a random seed only when none was given (seed=0 is a valid seed)
    if seed is None:
        seed = int(np.random.randint(1000))

    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    # Load the sketch as grayscale at 512x512, then binarize it:
    # pixels above 128 become white (255, the 8-bit maximum), the rest black (0)
    curr_sketch_image = Image.open(curr_sketch_path).convert("L").resize((512, 512))
    control_image = curr_sketch_image.point(lambda p: 255 if p > 128 else 0).convert("RGB")

    # Run the ControlNet pipeline conditioned on the binarized sketch
    images = pipe(
        prompt,
        image=control_image,
        negative_prompt=negative_prompt,
        num_inference_steps=num_steps,
        generator=generator,
        controlnet_conditioning_scale=1.0,
    ).images

    # Save the output image and report the seed so the result can be reproduced
    images[0].save(output_path)
    print(f"Saved {output_path} (seed={seed})")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Sketch-a-Sketch command line tool")
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt for the model")
    parser.add_argument("--curr_sketch_path", type=str, required=True, help="Path to the current sketch image")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the generated image")
    parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for the model")
    parser.add_argument("--num_steps", type=int, default=20, help="Number of denoising steps")
    parser.add_argument("--seed", type=int, default=None, help="Seed for the generator (random if omitted)")
    args = parser.parse_args()

    sketch(
        args.prompt,
        args.curr_sketch_path,
        args.output_path,
        negative_prompt=args.negative_prompt,
        num_steps=args.num_steps,
        seed=args.seed,
    )
```

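I'm not sure whether the negative prompt argument actually has an effect here, or whether it was hallucinated; `negative_prompt` is, however, a documented parameter of the diffusers `StableDiffusionControlNetPipeline` call.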
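The script binarizes the sketch before handing it to ControlNet, mapping every grayscale value to either black or white. Since 8-bit pixel values top out at 255, the white level should be 255, not 256. As a minimal sketch, the same threshold can be written as a lookup table, which Pillow's `Image.point` accepts in place of a function (256 entries per band). The helper name `make_threshold_lut` is hypothetical, and the cutoff of 128 is simply taken from the script; neither is part of Sketch-a-Sketch itself.

```python
def make_threshold_lut(cutoff: int = 128, bands: int = 1) -> list:
    """Build a Pillow-style lookup table that binarizes 8-bit pixels.

    Values strictly greater than `cutoff` map to 255 (white), the rest to 0 (black).
    Pillow expects 256 entries per band, so an RGB image needs bands=3.
    """
    if not 0 <= cutoff <= 255:
        raise ValueError("cutoff must be in the 8-bit range 0..255")
    single_band = [255 if p > cutoff else 0 for p in range(256)]
    return single_band * bands


# Usage with Pillow (assumed installed, as the script already requires it):
#   from PIL import Image
#   img = Image.open("sketch.png").convert("L").resize((512, 512))
#   binary = img.point(make_threshold_lut()).convert("RGB")
```

A table is computed once and reused for every pixel, and it makes the exact output values easy to inspect, which is how an off-by-one such as 256 is easy to spot.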

dataf3l commented 10 months ago

using this command:

```shell
python cli.py --prompt "a potato " --curr_sketch_path "./example2.png" --output_path "./image.jpg"
```

It converts this:

https://i.imgur.com/53cQf7a.png

into this:

https://i.imgur.com/k9p8noW.jpg

which is "good enough for me"

Admittedly, this potato is not the ultimate work of art, but hey, I put in about zero effort, so congratulations, guys, you did great work. I love it.

Looking forward to the Stable Diffusion XL version of this thing.
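The `cli.py` command in this thread samples a random seed on each run unless `--seed` is passed, so one way to compare several variations of the same sketch is to call the script repeatedly with fixed seeds. The sketch below only builds the argument lists and runs them with `subprocess.run`. The script name and flags come from the thread; the helper names `build_commands` and `run_all` and the `image_seed<N>.jpg` output naming are hypothetical choices for illustration.

```python
import subprocess
import sys
from pathlib import Path


def build_commands(prompt, sketch_path, out_dir, seeds, script="cli.py"):
    """Build one argument list per seed for the CLI script from this thread."""
    commands = []
    for seed in seeds:
        # Hypothetical naming scheme: one output file per seed, e.g. image_seed42.jpg
        output_path = Path(out_dir) / f"image_seed{seed}.jpg"
        commands.append([
            sys.executable, script,  # reuse the current Python interpreter to run cli.py
            "--prompt", prompt,
            "--curr_sketch_path", str(sketch_path),
            "--output_path", str(output_path),
            "--seed", str(seed),
        ])
    return commands


def run_all(commands):
    """Run each generation in turn; check=True stops at the first failing run."""
    for cmd in commands:
        subprocess.run(cmd, check=True)
```

For example, `run_all(build_commands("a potato", "./example2.png", ".", seeds=[7, 42, 123]))` would produce three images side by side. Passing arguments as a list rather than a single shell string keeps prompts with spaces or quotes intact without any extra escaping.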