damian0815 / compel

A prompting enhancement library for transformers-type text embedding systems

Using fp16 text_encoder seems to cause computational error #84

Closed GongXinyuu closed 4 months ago

GongXinyuu commented 8 months ago

Hi @damian0815, thanks for your great work! Recently I've tried to incorporate compel into https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py

However, I have found that there might be a computational error when running compel with the text encoders in torch.float16 dtype. Below is a full code snippet to reproduce it.

from transformers import AutoTokenizer, PretrainedConfig
from compel import Compel, ReturnedEmbeddingsType
import torch

pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0"

# helper from the diffusers train_text_to_image_lora_sdxl.py script linked above,
# reproduced here so the snippet is self-contained: it reads the model config and
# returns the matching text encoder class
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path, revision, subfolder="text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")

# import correct text encoder classes
text_encoder_cls_one = import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path, None
)
text_encoder_cls_two = import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path, None, subfolder="text_encoder_2"
)

text_encoder_one = text_encoder_cls_one.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="text_encoder",
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="text_encoder_2",
)

tokenizer_one = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)

def tokenize_prompt(tokenizer, prompt, truncation=True):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=truncation,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids

# borrowed from https://github.com/huggingface/diffusers/blob/5d848ec07c2011d600ce5e5c1aa02a03152aea9b/examples/text_to_image/train_text_to_image_lora_sdxl.py#L468
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
    prompt_embeds_list = []

    for i, text_encoder in enumerate(text_encoders):
        if tokenizers is not None:
            tokenizer = tokenizers[i]
            text_input_ids = tokenize_prompt(tokenizer, prompt)
        else:
            assert text_input_ids_list is not None
            text_input_ids = text_input_ids_list[i]

        prompt_embeds = text_encoder(
            text_input_ids.to(text_encoder.device),
            output_hidden_states=True,
            return_dict=False,
        )

        # We are only ALWAYS interested in the pooled output of the final text encoder
        pooled_prompt_embeds = prompt_embeds[0]
        prompt_embeds = prompt_embeds[-1][-2]
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
        prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)

    return prompt_embeds, pooled_prompt_embeds

text_encoder_one=text_encoder_one.to('cuda', torch.float16)
text_encoder_two=text_encoder_two.to('cuda', torch.float16)

compel = Compel(
    tokenizer=[tokenizer_one, tokenizer_two],
    text_encoder=[text_encoder_one, text_encoder_two],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
    truncate_long_prompts=False,
)

prompts = ['model photo,contrast lace gigot sleeve top,slim fit regular black,plain print,long leg of mutton sleeve,round neck,rib knit,elegant,contrast lace', 'garment photo,young boy slogan panda print contrast raglan sleeve tee,loose regular beige,cartoon,colorblock print,long raglan sleeve,round neck,fabric,casual']

with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        text_encoders=[text_encoder_one, text_encoder_two],
        tokenizers=[tokenizer_one, tokenizer_two],
        prompt=prompts,
    )
    prompt_embeds_compel, pooled_prompt_embeds_compel = compel(prompts)

torch.testing.assert_close(prompt_embeds, prompt_embeds_compel)

The above code will throw an error:

AssertionError: Tensor-likes are not close!

Mismatched elements: 230024 / 315392 (72.9%)
Greatest absolute difference: 0.125 at index (0, 0, 324) (up to 1e-05 allowed)
Greatest relative difference: inf at index (0, 0, 982) (up to 0.001 allowed)
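(For context, the reported 0.125 greatest absolute difference is far larger than ordinary fp16 rounding; one quick, illustrative way to confirm that is to relax the tolerances, with values chosen arbitrarily here rather than taken from compel or torch defaults:)

# Illustrative only: tolerances chosen arbitrarily to separate fp16 rounding
# noise from a genuine discrepancy; a 0.125 absolute difference would likely
# still fail even this loose comparison.
torch.testing.assert_close(prompt_embeds, prompt_embeds_compel, rtol=1e-2, atol=1e-2)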

I found that if the two lines

text_encoder_one=text_encoder_one.to('cuda', torch.float16)
text_encoder_two=text_encoder_two.to('cuda', torch.float16)

are removed, then it works perfectly well.
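(One possible workaround, sketched only and not verified in this thread: keep the text encoders in fp32 for compel's encoding pass and cast the resulting embeddings down to fp16 afterwards, assuming the rest of the training loop only needs the embeddings in half precision.)

# Sketch of a possible workaround (not verified in this thread): run the text
# encoders in fp32 for the encoding pass, then cast the embeddings to fp16 so
# the half-precision training loop can consume them.
text_encoder_one = text_encoder_one.to('cuda', torch.float32)
text_encoder_two = text_encoder_two.to('cuda', torch.float32)

with torch.no_grad():
    prompt_embeds_compel, pooled_prompt_embeds_compel = compel(prompts)

prompt_embeds_compel = prompt_embeds_compel.to(torch.float16)
pooled_prompt_embeds_compel = pooled_prompt_embeds_compel.to(torch.float16)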

damian0815 commented 7 months ago

Could be; I don't have the resources to be able to check at the moment. But are you trying to do training with compel? You probably shouldn't do that.

GongXinyuu commented 7 months ago

Could be; I don't have the resources to be able to check at the moment. But are you trying to do training with compel? You probably shouldn't do that.

Yes. I'm trying to incorporate compel into diffusion model LoRA training. May I know why it shouldn't be used for training? Thanks.

damian0815 commented 6 months ago

@GongXinyuu If you train with weighted captions, you'll possibly produce a model that will only respond to prompts that have been weighted the same way, but more likely a model that is just harder to use.
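(To make that point concrete, a hypothetical check reusing the compel object from the snippet above: the same caption with and without weighting syntax produces measurably different conditioning tensors, and that altered conditioning is what the model would be trained on.)

# Hypothetical check, reusing `compel` from the snippet above: the same caption
# with and without compel weighting syntax yields different conditioning tensors.
plain = "model photo, contrast lace gigot sleeve top, slim fit, black"
weighted = "model photo, (contrast lace)1.5 gigot sleeve top, slim fit, black"

with torch.no_grad():
    emb_plain, _ = compel([plain])
    emb_weighted, _ = compel([weighted])

print((emb_plain - emb_weighted).abs().max())  # nonzero: the model sees different data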

GongXinyuu commented 6 months ago

@GongXinyuu If you train with weighted captions, you'll possibly produce a model that will only respond to prompts that have been weighted the same way, but more likely a model that is just harder to use.

Gotcha. The main reason I want to incorporate compel into model training is to mitigate the 77-token limitation imposed by CLIP, as I want to use more detailed prompts to train the diffusion model, just like DALL-E 3. I guess it should be fine if I don't use the prompt weighting function?
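(As an aside, a small illustration of what dropping the 77-token limit via compel actually produces; the names reuse the snippet above, and the shapes in the comments are what one would expect for SDXL rather than verified output.)

# Illustration only (reusing `compel` from the snippet above): with
# truncate_long_prompts=False, compel splits a long prompt into 77-token chunks
# and concatenates their embeddings, so the sequence length grows in multiples of 77.
short_prompt = "a cat"
long_prompt = " ".join(["a very detailed description of a cat wearing a hat"] * 20)

with torch.no_grad():
    emb_short, _ = compel([short_prompt])
    emb_long, _ = compel([long_prompt])

print(emb_short.shape)  # expected (1, 77, 2048) for SDXL's two concatenated encoders
print(emb_long.shape)   # expected sequence length: a larger multiple of 77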

damian0815 commented 6 months ago

I'd suggest limiting the captions to 77 tokens, for similar reasons, but mostly because compel is not at all aware of word boundaries, which means you will end up training your model on broken text encoder data. The >77-token trick is an even uglier and more unreliable hack than the prompt weighting.
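(As a practical follow-up to that suggestion, a minimal sketch of filtering captions to the CLIP token limit before training; all_captions is a hypothetical list, and tokenizer_one comes from the snippet above.)

# Minimal sketch: keep only captions that fit CLIP's 77-token window.
# `all_captions` is hypothetical; `tokenizer_one` is from the snippet above;
# the count includes the BOS/EOS tokens the CLIP tokenizer adds.
def fits_clip_limit(caption, tokenizer, max_length=77):
    ids = tokenizer(caption, truncation=False, add_special_tokens=True).input_ids
    return len(ids) <= max_length

filtered_captions = [c for c in all_captions if fits_clip_limit(c, tokenizer_one)]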