xhinker / sd_embed

Generate long weighted prompt embeddings for Stable Diffusion
Apache License 2.0
85 stars 9 forks source link

`get_weighted_text_embeddings_sdxl` Fails with Empty Prompt #5

Closed thangnd-zenai closed 4 months ago

thangnd-zenai commented 4 months ago

Hi @xhinker, I see that your pipeline raises an error when run with:

prompt = ""
negative_prompt = ""

Bugs:

prompt_embeds = torch.cat(embeds, dim=1)
                    ^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: torch.cat(): expected a non-empty list of Tensors

How can this be fixed? Additionally, I sometimes encounter a bug in this line:

token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)

Bug:

Sizes of tensors must match except in dimension 2. Expected size 76 but got size 77 for tensor number 1 in the list.

I haven't been able to pin down exactly when the second error occurs, because it only appears when the product runs in production; I am still investigating its cause. I am opening this issue in the hope that you have ideas for solving either problem.

thangnd-zenai commented 4 months ago

This is my solution. In the `group_tokens_and_weights` function, change:

if len(token_ids) > 0: 

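Modify to (note that `len(token_ids) >= 0` is always true, so this effectively makes the branch unconditional):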

if len(token_ids) >= 0:

Other functions:

def pad_to_length(tokens, length, token_id):
    """Return a copy of `tokens` right-padded with `token_id` up to `length`.

    The input sequence is never modified. If it already has `length` or more
    items, a copy with the same contents is returned.

    Args:
        tokens: sequence of items (token ids or weights) to pad.
        length: minimum length of the returned list.
        token_id: value appended as padding.

    Returns:
        A new list with at least `length` items.
    """
    # Copy first so the caller's list is not mutated as a side effect.
    padded = list(tokens)
    # A non-positive repeat count yields [], so no explicit length check is needed.
    padded.extend([token_id] * (length - len(padded)))
    return padded

def get_embeds_from_encoder(encoder, token_tensor):
    """Run `encoder` on `token_tensor` and return two of its outputs.

    Args:
        encoder: callable text encoder that accepts `output_hidden_states`.
        token_tensor: token-id input passed straight to `encoder`.

    Returns:
        (penultimate_hidden, first_output): the second-to-last entry of the
        result's `hidden_states`, and element 0 of the result.
    """
    result = encoder(token_tensor, output_hidden_states = True)
    penultimate_hidden = result.hidden_states[-2]
    first_output = result[0]
    return penultimate_hidden, first_output

def _zero_pad_seq(embeds, length):
    """Right-pad `embeds` with zeros along dim 1 up to `length` (no-op if already long enough)."""
    missing = length - embeds.shape[1]
    if missing <= 0:
        return embeds
    # new_zeros keeps the dtype and device of `embeds`.
    padding = embeds.new_zeros(embeds.shape[0], missing, embeds.shape[-1])
    return torch.cat([embeds, padding], dim=1)


def get_prompt_embeddings(pipe, tokens1, tokens2, weights, target_len=77):
    """Encode one token chunk with both SDXL text encoders and apply per-token weights.

    Args:
        pipe: pipeline exposing `device`, `text_encoder` and `text_encoder_2`.
        tokens1: token ids fed to `pipe.text_encoder`.
        tokens2: token ids fed to `pipe.text_encoder_2`.
        weights: per-position weights. Entry j scales row j of the concatenated
            embedding. Positions without a weight keep weight 1.0, and weights
            beyond the sequence length are ignored.
        target_len: minimum sequence length both encoder outputs are
            zero-padded to (default 77, as before).

    Returns:
        (prompt_embeds, pooled_prompt_embeds):
        - prompt_embeds: the two encoders' outputs (as returned by
          get_embeds_from_encoder) concatenated on the last dim, with a
          leading batch dim of 1.
        - pooled_prompt_embeds: the second value get_embeds_from_encoder
          returns for `text_encoder_2`.
    """
    device = pipe.device
    token_tensor1 = torch.tensor([tokens1], dtype=torch.long, device=device)
    token_tensor2 = torch.tensor([tokens2], dtype=torch.long, device=device)

    # Get embeddings from text_encoder and text_encoder_2
    prompt_embeds1, _ = get_embeds_from_encoder(pipe.text_encoder, token_tensor1)
    prompt_embeds2, pooled_prompt_embeds = get_embeds_from_encoder(pipe.text_encoder_2, token_tensor2)

    # Both outputs must share their dim-1 length before they can be concatenated
    # on the feature dim. Padding both to a common length (at least target_len)
    # avoids "Sizes of tensors must match" errors when the two token lists
    # differ in length. It also avoids a negative pad size when an output is
    # already longer than target_len.
    seq_len = max(target_len, prompt_embeds1.shape[1], prompt_embeds2.shape[1])
    prompt_embeds1 = _zero_pad_seq(prompt_embeds1, seq_len)
    prompt_embeds2 = _zero_pad_seq(prompt_embeds2, seq_len)

    # Concatenate embeddings from both text encoders, then drop the batch dim.
    prompt_embeds = torch.cat([prompt_embeds1, prompt_embeds2], dim=-1).squeeze(0)

    # Apply all weights in one multiply instead of a per-row Python loop.
    # Missing positions get 1.0, and x * 1.0 == x, so those rows are unchanged.
    # Weights use the embedding dtype rather than a fixed float16.
    row_weights = list(weights)[:seq_len]
    row_weights += [1.0] * (seq_len - len(row_weights))
    weight_tensor = torch.tensor(row_weights, dtype=prompt_embeds.dtype, device=prompt_embeds.device)
    prompt_embeds = prompt_embeds * weight_tensor.unsqueeze(-1)

    return prompt_embeds.unsqueeze(0), pooled_prompt_embeds

def _pad_and_group(tokens, weights, length, pad_token, pad_last_block):
    """Pad one token/weight pair to `length` and split it with group_tokens_and_weights."""
    # Pass copies so the caller's lists are never modified, whatever
    # pad_to_length / group_tokens_and_weights do to their arguments.
    padded_tokens = pad_to_length(list(tokens), length, pad_token)
    padded_weights = pad_to_length(list(weights), length, 1.0)
    return group_tokens_and_weights(
        padded_tokens
        , padded_weights
        , pad_last_block = pad_last_block
    )


def get_weighted_text_embeddings_sdxl(
    pipe: StableDiffusionXLPipeline
    , prompt : str      = ""
    , neg_prompt: str   = ""
    , pad_last_block    = True
):
    """
    This function can process long prompt with weights, no length limitation
    for Stable Diffusion XL.

    Both prompts are tokenized with both SDXL tokenizers and padded to one
    common length. They are then split into chunks by group_tokens_and_weights
    and each chunk is encoded with get_prompt_embeddings. The chunk
    embeddings are concatenated along dim 1.

    Args:
        pipe (StableDiffusionXLPipeline): provides `tokenizer`, `tokenizer_2`
            and the text encoders used by get_prompt_embeddings.
        prompt (str): positive prompt. None or "" is treated as an empty prompt.
        neg_prompt (str): negative prompt. None or "" is treated as an empty prompt.
        pad_last_block (bool): forwarded to group_tokens_and_weights.

    Returns:
        prompt_embeds (torch.Tensor)
        negative_prompt_embeds (torch.Tensor)
        pooled_prompt_embeds (torch.Tensor): pooled output of the last chunk.
        negative_pooled_prompt_embeds (torch.Tensor): pooled output of the last chunk.

    Raises:
        ValueError: if grouping produced no chunks to encode.

    Example:
        from diffusers import StableDiffusionXLPipeline
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0"
            , torch_dtype = torch.float16
        ).to("cuda:0")
        (
            prompt_embeds
            , neg_prompt_embeds
            , pooled_prompt_embeds
            , neg_pooled_prompt_embeds
        ) = get_weighted_text_embeddings_sdxl(
            pipe
            , prompt = "a (white) cat"
            , neg_prompt = "blur"
        )
        image = pipe(
            prompt_embeds = prompt_embeds
            , negative_prompt_embeds = neg_prompt_embeds
            , pooled_prompt_embeds = pooled_prompt_embeds
            , negative_pooled_prompt_embeds = neg_pooled_prompt_embeds
            , generator = torch.Generator(pipe.device).manual_seed(2)
        ).images[0]
    """
    # Treat None like an empty prompt so the tokenizer helper always gets a string.
    prompt = prompt or ""
    neg_prompt = neg_prompt or ""

    # Each tokenizer's sequences are padded with that tokenizer's own EOS id.
    # For the SDXL CLIP tokenizers these presumably coincide; confirm per model.
    eos = pipe.tokenizer.eos_token_id
    eos_2 = pipe.tokenizer_2.eos_token_id

    # tokenizer 1
    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)

    # tokenizer 2
    prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt)
    neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt)

    # Pad all four sequences to one length so that chunks line up across the
    # prompt/negative prompt and across both tokenizers. The floor of 1 keeps
    # the token lists non-empty when both prompts are empty. Otherwise no chunk
    # may be produced and torch.cat below fails on an empty list.
    max_len = max(
        1
        , len(prompt_tokens)
        , len(neg_prompt_tokens)
        , len(prompt_tokens_2)
        , len(neg_prompt_tokens_2)
    )

    prompt_token_groups, prompt_weight_groups = _pad_and_group(
        prompt_tokens, prompt_weights, max_len, eos, pad_last_block
    )
    neg_prompt_token_groups, neg_prompt_weight_groups = _pad_and_group(
        neg_prompt_tokens, neg_prompt_weights, max_len, eos, pad_last_block
    )
    # Only the tokenizer-1 weights are applied to the concatenated embedding
    # (get_prompt_embeddings takes a single weight list), so the tokenizer-2
    # weight groups are discarded.
    prompt_token_groups_2, _ = _pad_and_group(
        prompt_tokens_2, prompt_weights_2, max_len, eos_2, pad_last_block
    )
    neg_prompt_token_groups_2, _ = _pad_and_group(
        neg_prompt_tokens_2, neg_prompt_weights_2, max_len, eos_2, pad_last_block
    )

    if not prompt_token_groups or not neg_prompt_token_groups:
        raise ValueError("prompt tokenization produced no token chunks to encode")

    embeds = []
    neg_embeds = []

    # Encode chunk by chunk. The pooled embeddings are overwritten on every
    # iteration, so the values from the final chunk are returned.
    for tokens, tokens_2, weights, neg_tokens, neg_tokens_2, neg_weights in zip(
        prompt_token_groups
        , prompt_token_groups_2
        , prompt_weight_groups
        , neg_prompt_token_groups
        , neg_prompt_token_groups_2
        , neg_prompt_weight_groups
    ):
        chunk_embeds, pooled_prompt_embeds = get_prompt_embeddings(
            pipe, tokens, tokens_2, weights
        )
        embeds.append(chunk_embeds)

        neg_chunk_embeds, negative_pooled_prompt_embeds = get_prompt_embeddings(
            pipe, neg_tokens, neg_tokens_2, neg_weights
        )
        neg_embeds.append(neg_chunk_embeds)

    prompt_embeds           = torch.cat(embeds, dim = 1)
    negative_prompt_embeds  = torch.cat(neg_embeds, dim = 1)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

What do you think?

xhinker commented 4 months ago

Added logic to handle empty and `None` prompts: https://github.com/xhinker/sd_embed/pull/6

You may also want to add logic that validates the user's prompt before generating the embeddings.