This is my solution. In the `group_tokens_and_weights` function, change:

```python
if len(token_ids) > 0:
```

to:

```python
if len(token_ids) >= 0:
```

so that an empty token list still produces one padded group instead of none.
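For context, here is a minimal sketch of the kind of chunking logic this one-character change affects. This is not the exact upstream `sd_embed` implementation; the helper name, the BOS/EOS ids, and the 75-token block size are assumptions based on CLIP's 77-token window:

```python
# Hypothetical sketch (not the actual sd_embed code): split token ids into
# 75-token blocks and wrap each block with BOS/EOS to form 77-token groups.
def group_tokens_and_weights_sketch(token_ids, weights, pad_last_block=True,
                                    bos=49406, eos=49407):
    groups, weight_groups = [], []
    while len(token_ids) >= 75:
        groups.append([bos] + token_ids[:75] + [eos])
        weight_groups.append([1.0] + weights[:75] + [1.0])
        token_ids, weights = token_ids[75:], weights[75:]
    # With `>= 0` instead of `> 0`, an empty remainder still emits a fully
    # padded group, so an empty prompt does not break downstream code.
    if len(token_ids) >= 0:
        pad_len = 75 - len(token_ids) if pad_last_block else 0
        groups.append([bos] + token_ids + [eos] * (pad_len + 1))
        weight_groups.append([1.0] + weights + [1.0] * (pad_len + 1))
    return groups, weight_groups
```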
Other helper functions:

```python
import torch
from diffusers import StableDiffusionXLPipeline

def pad_to_length(tokens, length, token_id):
    # Right-pad a token (or weight) list to the target length.
    if len(tokens) < length:
        tokens += [token_id] * (length - len(tokens))
    return tokens

def get_embeds_from_encoder(encoder, token_tensor):
    # Return the penultimate hidden state (the layer SDXL conditions on)
    # and the encoder's first output (the pooled embedding for text_encoder_2).
    output = encoder(token_tensor, output_hidden_states=True)
    return output.hidden_states[-2], output[0]
```
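For example (the token ids here are illustrative; 49407 is CLIP's end-of-text id):

```python
pad_to_length([320, 2368], 5, 49407)  # -> [320, 2368, 49407, 49407, 49407]
```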
The per-group embedding function:

```python
def get_prompt_embeddings(pipe, tokens1, tokens2, weights):
    token_tensor1 = torch.tensor([tokens1], dtype=torch.long, device=pipe.device)
    token_tensor2 = torch.tensor([tokens2], dtype=torch.long, device=pipe.device)
    # Assumes an fp16 pipeline; torch promotes the dtype if the embeds differ.
    weight_tensor = torch.tensor(weights, dtype=torch.float16, device=pipe.device)

    # Get embeddings from text_encoder and text_encoder_2
    prompt_embeds1, _ = get_embeds_from_encoder(pipe.text_encoder, token_tensor1)
    prompt_embeds2, pooled_prompt_embeds = get_embeds_from_encoder(
        pipe.text_encoder_2, token_tensor2
    )

    # Zero-pad either embedding to 77 positions if the group came up short
    if prompt_embeds1.shape[1] != 77:
        prompt_embeds1 = torch.cat([
            prompt_embeds1,
            torch.zeros(1, 77 - prompt_embeds1.shape[1], prompt_embeds1.shape[-1],
                        device=prompt_embeds1.device, dtype=prompt_embeds1.dtype)
        ], dim=1)
    if prompt_embeds2.shape[1] != 77:
        prompt_embeds2 = torch.cat([
            prompt_embeds2,
            torch.zeros(1, 77 - prompt_embeds2.shape[1], prompt_embeds2.shape[-1],
                        device=prompt_embeds2.device, dtype=prompt_embeds2.dtype)
        ], dim=1)

    # Concatenate embeddings from both text encoders
    prompt_embeds = torch.cat([prompt_embeds1, prompt_embeds2], dim=-1).squeeze(0)

    # Scale each token's embedding by its weight
    for j in range(len(weight_tensor)):
        if weight_tensor[j] != 1.0:
            prompt_embeds[j] = prompt_embeds[j] * weight_tensor[j]
    return prompt_embeds.unsqueeze(0), pooled_prompt_embeds
```
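For reference, with the stock SDXL text encoders this yields, per 77-token group: `prompt_embeds1` of shape (1, 77, 768) from `text_encoder`, `prompt_embeds2` of shape (1, 77, 1280) from `text_encoder_2`, a concatenated embedding of shape (1, 77, 2048), and `pooled_prompt_embeds` of shape (1, 1280).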
Finally, the main function:

```python
def get_weighted_text_embeddings_sdxl(
    pipe: StableDiffusionXLPipeline,
    prompt: str = "",
    neg_prompt: str = "",
    pad_last_block: bool = True,
):
    """
    Process a long prompt with weights for Stable Diffusion XL,
    with no length limitation.
    Args:
        pipe (StableDiffusionXLPipeline)
        prompt (str)
        neg_prompt (str)
    Returns:
        prompt_embeds (torch.Tensor)
        neg_prompt_embeds (torch.Tensor)
        pooled_prompt_embeds (torch.Tensor)
        negative_pooled_prompt_embeds (torch.Tensor)
    Example:
        from diffusers import StableDiffusionXLPipeline
        text2img_pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
        ).to("cuda:0")
        (
            prompt_embeds, neg_prompt_embeds,
            pooled_prompt_embeds, negative_pooled_prompt_embeds,
        ) = get_weighted_text_embeddings_sdxl(
            pipe=text2img_pipe,
            prompt="a (white) cat",
            neg_prompt="blur",
        )
        image = text2img_pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=neg_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            generator=torch.Generator(text2img_pipe.device).manual_seed(2),
        ).images[0]
    """
    eos = pipe.tokenizer.eos_token_id

    # Tokenize prompt and negative prompt with both tokenizers
    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, prompt
    )
    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, neg_prompt
    )
    prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(
        pipe.tokenizer_2, prompt
    )
    neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(
        pipe.tokenizer_2, neg_prompt
    )

    # Pad the shorter sequences so all four token lists share one length
    max_len = max(
        len(prompt_tokens), len(neg_prompt_tokens),
        len(prompt_tokens_2), len(neg_prompt_tokens_2),
    )
    prompt_tokens = pad_to_length(prompt_tokens, max_len, eos)
    neg_prompt_tokens = pad_to_length(neg_prompt_tokens, max_len, eos)
    prompt_tokens_2 = pad_to_length(prompt_tokens_2, max_len, eos)
    neg_prompt_tokens_2 = pad_to_length(neg_prompt_tokens_2, max_len, eos)
    prompt_weights = pad_to_length(prompt_weights, max_len, 1.0)
    neg_prompt_weights = pad_to_length(neg_prompt_weights, max_len, 1.0)
    prompt_weights_2 = pad_to_length(prompt_weights_2, max_len, 1.0)
    neg_prompt_weights_2 = pad_to_length(neg_prompt_weights_2, max_len, 1.0)

    embeds = []
    neg_embeds = []

    # Split each token/weight list into 77-token groups
    prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
        prompt_tokens.copy(), prompt_weights.copy(), pad_last_block=pad_last_block
    )
    neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
        neg_prompt_tokens.copy(), neg_prompt_weights.copy(), pad_last_block=pad_last_block
    )
    prompt_token_groups_2, prompt_weight_groups_2 = group_tokens_and_weights(
        prompt_tokens_2.copy(), prompt_weights_2.copy(), pad_last_block=pad_last_block
    )
    neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = group_tokens_and_weights(
        neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy(), pad_last_block=pad_last_block
    )

    # Encode each 77-token group (getting prompt embeddings one by one
    # does not work)
    for i in range(len(prompt_token_groups)):
        # positive prompt embeddings with weights
        prompt_embeds, pooled_prompt_embeds = get_prompt_embeddings(
            pipe, prompt_token_groups[i], prompt_token_groups_2[i], prompt_weight_groups[i]
        )
        embeds.append(prompt_embeds)
        # negative prompt embeddings with weights
        neg_prompt_embeds, negative_pooled_prompt_embeds = get_prompt_embeddings(
            pipe, neg_prompt_token_groups[i], neg_prompt_token_groups_2[i], neg_prompt_weight_groups[i]
        )
        neg_embeds.append(neg_prompt_embeds)

    # Concatenate the per-group embeddings along the sequence dimension
    prompt_embeds = torch.cat(embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
```
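One caveat: because `pooled_prompt_embeds` and `negative_pooled_prompt_embeds` are reassigned on every loop iteration, the pooled embeddings that are returned come from the last 77-token group only.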
What do you think?
I have provided additional logic to handle empty and None prompts: https://github.com/xhinker/sd_embed/pull/6
You may also consider adding logic to validate the user's prompt before generating the embeddings; a sketch follows.
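A minimal sketch of such a check, with `sanitize_prompt` as a hypothetical helper name (not part of sd_embed):

```python
def sanitize_prompt(prompt):
    # Normalize None to an empty string and reject non-string input
    # before tokenization, so downstream code always sees a str.
    if prompt is None:
        return ""
    if not isinstance(prompt, str):
        raise TypeError(f"prompt must be a str, got {type(prompt).__name__}")
    return prompt.strip()
```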
Hi @xhinker, I see that your pipeline encounters an error when run with:
Bugs:
How can this be fixed? Additionally, I sometimes encounter a bug in this line:
Bug:
As for the second error, I haven't been able to trace exactly when it occurs, because it only appears when I run the product in production; I am still investigating the cause. I am writing this issue in the hope that you can provide a solution to these problems, if you have any ideas.