replicate / dreambooth-template

A template repo for training and publishing your own custom Stable Diffusion model using https://replicate.com/replicate/dreambooth
Apache License 2.0

support for longer prompts #26

Open anotherjesse opened 1 year ago

anotherjesse commented 1 year ago

Users have asked for support for longer prompts.

This might be the answer: https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega

anotherjesse commented 1 year ago

@daanelson did an initial proof of concept with:

            custom_pipeline="lpw_stable_diffusion",

https://github.com/daanelson/cog-stable-diffusion-long-prompt

And pushed a version to https://replicate.com/daanelson/stable-diffusion-long-prompts

He shares that this pipeline is slower even for short prompts, so it might not be a good default.
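One way to avoid paying that cost on every request would be to route only over-length prompts to the long-prompt pipeline. The sketch below is an illustration, not something from the proof of concept: `pick_pipeline` and its arguments are hypothetical names, and it assumes the tokenizer behaves like the Hugging Face CLIP tokenizer (callable, returns `input_ids`, exposes `model_max_length`).

```python
def pick_pipeline(prompt, tokenizer, normal_pipe, long_prompt_pipe):
    """Return the long-prompt pipeline only when the prompt would be truncated.

    Hypothetical helper: the names and the routing rule are assumptions made
    for illustration, not part of dreambooth-template.
    """
    # Tokenize without truncation so the real token count is visible.
    # For a CLIP tokenizer the count includes the start and end tokens.
    n_tokens = len(tokenizer(prompt, truncation=False).input_ids)

    # model_max_length is the longest sequence the text encoder accepts.
    if n_tokens > tokenizer.model_max_length:
        return long_prompt_pipe
    return normal_pipe
```

With this, short prompts would keep the current speed and only long prompts would take the slower path. Whether the two pipelines produce comparable images for the same short prompt would still need checking.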

He also suggested that we should be able to share the VAE and other sub-models between a "long prompt" version and a "normal" version, similar to how we currently share the weights between img2img and txt2img:

https://github.com/replicate/dreambooth-template/blob/main/predict.py#L63-L78

    self.txt2img_pipe = StableDiffusionPipeline.from_pretrained(
        "weights",
        safety_checker=self.safety_checker,
        feature_extractor=feature_extractor,
        torch_dtype=torch.float16,
    ).to("cuda")

    self.img2img_pipe = StableDiffusionImg2ImgPipeline(
        vae=self.txt2img_pipe.vae,
        text_encoder=self.txt2img_pipe.text_encoder,
        tokenizer=self.txt2img_pipe.tokenizer,
        unet=self.txt2img_pipe.unet,
        scheduler=self.txt2img_pipe.scheduler,
        safety_checker=self.txt2img_pipe.safety_checker,
        feature_extractor=self.txt2img_pipe.feature_extractor,
    ).to("cuda")

But we would need to validate this first.
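A minimal sketch of what that sharing could look like, following the same pattern as the img2img pipe above. This is untested against real weights: `build_long_prompt_pipe`, `SHARED_COMPONENTS`, and the `loader` parameter are hypothetical names introduced here, and it assumes the `lpw_stable_diffusion` community pipeline accepts the same seven components as `StableDiffusionPipeline`.

```python
# Sub-models reused from the already-loaded txt2img pipeline, mirroring the
# component list passed to StableDiffusionImg2ImgPipeline in predict.py.
SHARED_COMPONENTS = (
    "vae",
    "text_encoder",
    "tokenizer",
    "unet",
    "scheduler",
    "safety_checker",
    "feature_extractor",
)


def build_long_prompt_pipe(base_pipe, weights_dir="weights", loader=None):
    """Build a long-prompt pipeline that reuses base_pipe's sub-models.

    Hypothetical helper. `loader` defaults to DiffusionPipeline.from_pretrained
    and exists only so the function can be exercised without a GPU.
    """
    if loader is None:
        # Imported lazily so the module loads even without diffusers installed.
        from diffusers import DiffusionPipeline

        loader = DiffusionPipeline.from_pretrained

    # Components passed as keyword arguments are used as-is instead of being
    # loaded from disk again, so no second copy of the weights is created.
    shared = {name: getattr(base_pipe, name) for name in SHARED_COMPONENTS}
    return loader(weights_dir, custom_pipeline="lpw_stable_diffusion", **shared)
```

In `setup()` this might be called as `self.long_prompt_pipe = build_long_prompt_pipe(self.txt2img_pipe)`. Because the sub-models are the same objects already on the GPU, no extra `.to("cuda")` should be needed, though that is part of what would have to be validated.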