damian0815 / compel

A prompting enhancement library for transformers-type text embedding systems
MIT License
526 stars 47 forks source link

Stable Cascade support, new ReturnedEmbeddingsType #104

Open Teriks opened 2 months ago

Teriks commented 2 months ago

Stable Cascade support in EmbeddingsProvider via new returned embeddings type.

Usage:


import gc
import torch
import compel

from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

# Target device for both pipelines; compel also places its tensors here.
device = 'cuda'

# Prompts use compel's attention-weighting syntax: "(span)1.4" scales the
# attention for that span, and a trailing '+' boosts a single token.
prompt = "an image of a shiba inu with (blue eyes)1.4, donning a green+ spacesuit, (cartoon style)1.6"
neg_prompt = "photograph, (real)1.6"

# Fixed seed so the prior and decoder runs are reproducible.
generator = torch.Generator(device=device).manual_seed(0)

# prior

# Stage 1: the Stable Cascade prior turns text embeddings into image embeddings.
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant='bf16',
                                                   torch_dtype=torch.bfloat16).to(device)

# Wrap the prior's tokenizer/text encoder in compel. The STABLE_CASCADE
# returned-embeddings type selects the layout this pipeline expects, and
# requires_pooled=True makes the call also return the pooled embedding
# needed by the *_pooled arguments below.
prior_compel = compel.Compel(tokenizer=prior.tokenizer,
                             text_encoder=prior.text_encoder,
                             requires_pooled=True,
                             device=device,
                             returned_embeddings_type=compel.ReturnedEmbeddingsType.STABLE_CASCADE)

conditioning, pooled = prior_compel(prompt)
neg_conditioning, neg_pooled = prior_compel(neg_prompt)

# Positive and negative prompts may tokenize to different lengths; pad so the
# embedding tensors match before handing both to the pipeline.
conditioning, neg_conditioning = \
    prior_compel.pad_conditioning_tensors_to_same_length([conditioning, neg_conditioning])

prior_output = prior(
    num_inference_steps=20,
    guidance_scale=4,
    prompt_embeds=conditioning,
    prompt_embeds_pooled=pooled,
    negative_prompt_embeds=neg_conditioning,
    negative_prompt_embeds_pooled=neg_pooled,
    generator=generator
)

del conditioning, pooled, neg_conditioning, neg_pooled
prior.to('cpu')
# Actually release the prior's VRAM before the decoder loads: .to('cpu')
# alone leaves the freed blocks parked in the CUDA caching allocator, so
# without this the offload does not lower peak GPU memory (matches the
# cleanup already done at the end of the script).
gc.collect()
torch.cuda.empty_cache()

# decoder

# Stage 2: the Stable Cascade decoder renders the final image from the
# prior's image embeddings. Weights run in fp16 here — hence the .half()
# on the image embeddings below.
decoder_pipe = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant='bf16',
                                                            torch_dtype=torch.float16).to(device)

# Same compel configuration as the prior stage, but bound to the decoder's
# own tokenizer/text encoder so the prompts are re-encoded for this pipeline.
compel_decoder = compel.Compel(tokenizer=decoder_pipe.tokenizer,
                               text_encoder=decoder_pipe.text_encoder,
                               device=device,
                               requires_pooled=True,
                               returned_embeddings_type=compel.ReturnedEmbeddingsType.STABLE_CASCADE)

pos_embeds, pos_pooled = compel_decoder(prompt)
neg_embeds, neg_pooled = compel_decoder(neg_prompt)

# Equalize the token-sequence lengths of the two embedding tensors.
pos_embeds, neg_embeds = compel_decoder.pad_conditioning_tensors_to_same_length(
    [pos_embeds, neg_embeds])

result = decoder_pipe(
    num_inference_steps=10,
    guidance_scale=0.0,
    prompt_embeds=pos_embeds,
    prompt_embeds_pooled=pos_pooled,
    negative_prompt_embeds=neg_embeds,
    negative_prompt_embeds_pooled=neg_pooled,
    image_embeddings=prior_output.image_embeddings.half(),
    generator=generator
)
result.images[0].save('test.png')

del pos_embeds, pos_pooled, neg_embeds, neg_pooled

# Free GPU memory now that generation is finished.
decoder_pipe.to('cpu')
gc.collect()
torch.cuda.empty_cache()

Output Example:

(example output image — saved by the script as test.png)