damian0815 / compel

A prompting enhancement library for transformers-type text embedding systems
MIT License
486 stars 42 forks source link

Stable Cascade? #95

Open Teriks opened 4 days ago

Teriks commented 4 days ago

It seems like it might be possible for this to work with stable cascade?

I am wondering if there is a working snippet for prior + decoder or if it is incompatible at the moment.

Teriks commented 4 days ago

This generates a recognizable image, though given the quality of the image, there is definitely something missing from the equation somewhere. Though it seems somewhat possible.


import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

device = 'cuda'
prompt = "an image of a (shiba inu)1.5 donning a spacesuit++"

prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant='bf16',
                                                   torch_dtype=torch.bfloat16).to(device)

prior_compel = Compel(tokenizer=prior.tokenizer,
                      text_encoder=prior.text_encoder,
                      returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
                      requires_pooled=True, device=device)

conditioning, pooled = prior_compel(prompt)

prior_output = prior(
    num_inference_steps=20,
    guidance_scale=4,
    prompt_embeds=conditioning,
    prompt_embeds_pooled=pooled.unsqueeze(1))

prior.to('cpu')

decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant='bf16',
                                                       torch_dtype=torch.float16).to(device)

decoder_compel = Compel(tokenizer=decoder.tokenizer,
                        text_encoder=decoder.text_encoder,
                        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
                        requires_pooled=True,
                        device=device)

conditioning, pooled = decoder_compel(prompt)

decoder(num_inference_steps=10,
        guidance_scale=0.0,
        prompt_embeds=conditioning,
        prompt_embeds_pooled=pooled.unsqueeze(1),
        image_embeddings=prior_output.image_embeddings.half()).images[0].save('test.png')