kohya-ss / sd-scripts

How do SD's text_tokenizer and Unet work when the input prompt is too long? #1211

Open xddun opened 4 months ago

xddun commented 4 months ago

Question: How do SD's text_tokenizer and Unet work when the input prompt is too long?

Description: Hello, esteemed expert! I have a question. When I use AUTOMATIC1111/stable-diffusion-webui, I found that I can input prompts longer than 77 tokens, and the extra text still affects the generated images. I don't understand how this works. For example:

from transformers import CLIPTokenizer, CLIPTextModel

# SD v1.x uses CLIP ViT-L/14 as its text encoder (openai/clip-vit-large-patch14)
text_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda")

prompt = "a photograph of an astronaut riding a horse"
# pad/truncate the prompt to exactly 77 token ids
text_input_ids = text_tokenizer(
    prompt,
    padding="max_length",
    max_length=77,
    truncation=True,
    return_tensors="pt"
).input_ids
# [0] is the last hidden state, shape [1, 77, 768]
text_embeddings = text_encoder(text_input_ids.to("cuda"))[0]

The output is torch.Size([1, 77, 768]). I don't understand how the text_tokenizer supports such long prompt inputs, or how these overlong prompts are used in the U-Net's cross-attention. I have looked at the code in your repository, but I still haven't found the answer. Forgive my ignorance, and I humbly ask for your guidance.

BootsofLagrangian commented 4 months ago

First, the U-Net can consume a batch of text-encoder outputs, e.g. a tensor of shape [n, 77, 768]. So the training scripts use this property to extend the supported token length to 75, 150, 225, and so on.

The reason the step is 75 rather than 77 is that the first id of each input is the begin token, <bos>, and the last is the end token, <eos>. So the scripts focus on the 75 useful tokens in the middle, the pure prompt text.
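For illustration (not from the original comment): with the stock SD 1.x tokenizer, assumed here to be openai/clip-vit-large-patch14, the <bos> id is 49406 and the <eos>/padding id is 49407, so only 75 of the 77 positions carry prompt text.

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids = tokenizer(
    "a photograph of an astronaut riding a horse",
    padding="max_length", max_length=77, truncation=True,
).input_ids

print(len(ids))  # 77
print(ids[0])    # 49406 -> <|startoftext|> (bos)
print(ids[-1])   # 49407 -> <|endoftext|>, also used to pad out to 77
# bos and eos occupy 2 of the 77 slots, leaving 75 for the prompt itself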

Utilizing both properties together: for a token length of 225, for example, the tokenizer's pseudo-output is split into three 77-token rows, and after the text encoder the U-Net input has shape [3, 77, 768]. Let [3, 77, 768] be the shape of this input, i.e. input.shape == (3, 77, 768).

Then:

input[0] = [<bos> + first 75 tokens of the prompt + <eos>]
input[1] = [<bos> + second 75 tokens of the prompt + <eos>]
input[2] = [<bos> + third 75 tokens of the prompt + <eos>]

By doing this manipulation, A1111 and sd-scripts can accept prompts longer than 75 tokens.
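In code, a minimal sketch of this manipulation might look like the following (illustrative only, not the exact sd-scripts implementation; it assumes the Hugging Face transformers CLIP classes and, for simplicity, concatenates the per-chunk embeddings along the token axis before they reach the U-Net's cross-attention, which accepts a key/value sequence of any length):

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").eval()

def encode_long_prompt(prompt: str, chunk_size: int = 75) -> torch.Tensor:
    # Tokenize without the 77-token limit, then strip the automatic <bos>/<eos>.
    ids = tokenizer(prompt).input_ids[1:-1]
    bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id

    chunks = []
    for i in range(0, len(ids), chunk_size):
        chunk = [bos] + ids[i:i + chunk_size] + [eos]
        chunk += [eos] * (chunk_size + 2 - len(chunk))  # pad the last chunk to 77
        chunks.append(chunk)

    with torch.no_grad():
        # Each chunk is encoded separately: shape [1, 77, 768] per chunk.
        embeds = [text_encoder(torch.tensor([c])).last_hidden_state for c in chunks]

    # Glue the chunks back together along the token axis for cross-attention.
    return torch.cat(embeds, dim=1)  # [1, 77 * n_chunks, 768]

long_prompt = ", ".join(["a photograph of an astronaut riding a horse"] * 25)
print(encode_long_prompt(long_prompt).shape)  # e.g. three chunks -> torch.Size([1, 231, 768])

The real implementations add refinements on top of this (e.g. how the duplicated <bos>/<eos> positions and prompt weighting are handled), but the core idea is the chunking described above.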