huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

[BUG] Converting the T5 text encoder to float16 results in corrupted images #8604

Open Luciennnnnnn opened 1 month ago

Luciennnnnnn commented 1 month ago

Describe the bug

I have tested PixArt-Sigma with the following code, where I load text_encoder separately because I will fine-tune it later. I found that T5EncoderModel.from_pretrained(torch_dtype=torch.float16) behaves very differently from T5EncoderModel.from_pretrained().to(dtype=torch.float16): the latter produces corrupted images.

What's happening when we pass torch_dtype argument to from_pretrained?

Reproduction

from diffusers import PixArtSigmaPipeline
import torch

from transformers import T5EncoderModel

# text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", subfolder="text_encoder", torch_dtype=torch.float16) # good result
text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", subfolder="text_encoder").to(dtype=torch.float16) # noise

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    text_encoder=text_encoder,
    torch_dtype=torch.float16
)

pipe = pipe.to("cuda")

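prompts = ["a space elevator, cinematic scifi art"]  # a list, so enumerate() yields whole prompts, not characters

for idx, prompt in enumerate(prompts):
    image = pipe(prompt=prompt, num_inference_steps=50, generator=torch.manual_seed(1)).images[0]
    image.save(f"x_{idx}.png")  # one file per prompt instead of overwriting x.png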

Logs

No response

System Info

Who can help?

@sayakpaul @yiyixuxu

sayakpaul commented 1 month ago

Can you check if the outputs of the text encoders vary when loaded using the method you described?

That will be an easier way to reproduce the problem.
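A minimal sketch of that check, assuming recent `torch` and `transformers` releases. The helper names `max_embedding_diff` and `compare_loading_methods` are hypothetical; the checkpoint and prompt are the ones from the reproduction above, and the assumption that the repo ships a `tokenizer` subfolder is ours. The helper runs both encoders on the same token ids and reports the largest absolute difference between their last hidden states.

```python
import torch
from transformers import AutoTokenizer, T5EncoderModel

REPO = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"


@torch.no_grad()
def max_embedding_diff(encoder_a, encoder_b, input_ids):
    """Hypothetical helper: largest absolute difference between two encoders' outputs."""
    out_a = encoder_a(input_ids=input_ids).last_hidden_state
    out_b = encoder_b(input_ids=input_ids).last_hidden_state
    # Compare in float32 so the difference itself is not rounded to fp16.
    return (out_a.float() - out_b.float()).abs().max().item()


def compare_loading_methods(repo=REPO, prompt="a space elevator, cinematic scifi art", device="cuda"):
    """Hypothetical helper: load the encoder both ways and compare the embeddings."""
    tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer")
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Method 1: dtype passed to from_pretrained
    enc_kwarg = T5EncoderModel.from_pretrained(
        repo, subfolder="text_encoder", torch_dtype=torch.float16
    ).to(device)
    # Method 2: load in float32, then cast every parameter with nn.Module.to
    enc_cast = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder").to(
        device, dtype=torch.float16
    )
    return max_embedding_diff(enc_kwarg, enc_cast, input_ids)
```

On a machine with a large GPU, `print(compare_loading_methods())` would report how far apart the two embeddings are; both copies of the encoder are held in memory at once, so this is heavy for the full T5 encoder.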

Cc: @lawrence-cj

asomoza commented 1 month ago

Interesting, I can reproduce this error. These are the outputs:

# text_encoder = T5EncoderModel.from_pretrained(...).to(dtype=torch.float16)
tensor([[[ 0.0872, -0.0144, -0.0733,  ...,  0.0432,  0.0251,  0.1550],
         [ 0.0277, -0.1429, -0.1173,  ...,  0.0565, -0.1959,  0.0936],
         [-0.0569,  0.1390, -0.1050,  ...,  0.0665,  0.0408,  0.1098]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)
# text_encoder = T5EncoderModel.from_pretrained(..., torch_dtype=torch.float16)
tensor([[[-1.2744e-01, -1.4755e-02, -6.3416e-02,  ...,  1.0626e-01,
          -3.7567e-02, -1.1975e-01],
         [-1.1462e-01,  6.1569e-03,  1.1475e-01,  ..., -3.8208e-02,
          -1.1078e-01, -1.0980e-01],
         [-5.2605e-03, -7.7438e-03,  3.5763e-06,  ..., -3.6888e-03,
           7.2136e-03,  2.2907e-03]]], device='cuda:0', dtype=torch.float16,
       grad_fn=<MulBackward0>)

I found that with the second one, some layers are still in torch.float32.
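A minimal sketch for spotting such layers, using only `torch`; `summarize_dtypes` is a hypothetical helper name, and the toy module below merely stands in for a T5 feed-forward block. It groups parameter names by dtype, so running it on each loaded encoder should make any float32 leftovers visible.

```python
from collections import defaultdict

import torch
from torch import nn


def summarize_dtypes(model: nn.Module) -> dict:
    """Hypothetical helper: map each parameter dtype to the parameter names using it."""
    by_dtype = defaultdict(list)
    for name, param in model.named_parameters():
        by_dtype[param.dtype].append(name)
    return dict(by_dtype)


# Toy stand-in: cast everything to fp16, then put `wo` back in fp32.
toy = nn.ModuleDict({"wi": nn.Linear(4, 8), "wo": nn.Linear(8, 4)}).half()
toy["wo"].float()

for dtype, names in summarize_dtypes(toy).items():
    print(dtype, names)
```

On a real encoder, `summarize_dtypes(text_encoder).get(torch.float32, [])` would list the parameters that stayed in full precision.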

sayakpaul commented 1 month ago

Ccing a Transformers maintainer here: @ArthurZucker

yiyixuxu commented 1 month ago

I think that for T5, certain layers are upcast to float32 when the checkpoint is loaded with from_pretrained in fp16: https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/models/t5/modeling_t5.py#L797

Luciennnnnnn commented 1 month ago

If certain layers need to be upcast to float32, is the SD3 training code correct? In the SD3 training code, the T5 text encoder is initially loaded in float32 and then converted to float16 with the to() method when mixed-precision training with fp16 is used. We do not appear to encounter similar issues when loading the T5 text encoder in SD3. Could this be due to differences between the T5 encoder used in PixArt-Sigma and the one in SD3?

sayakpaul commented 1 month ago

Training does not seem to be affected by this :/

Luciennnnnnn commented 1 month ago

Training does not seem to be affected by this :/

Why? If some parameters of T5 have to be in float32, casting them to float16 will give the flow transformer inferior text features.

sayakpaul commented 1 month ago

Could very well be but the qualitative samples haven’t told me that yet.

This needs a deeper investigation. But the problem could stem from the fact that the original checkpoints are in float16 and I am not exactly sure about the consequences of any kind of casting here yet.
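For context, the generic effect of a float32 → float16 cast can be shown in a few lines of `torch`. This only illustrates fp16's range and precision limits; whether those limits are what corrupts the PixArt embeddings is an assumption that this thread has not confirmed.

```python
import torch

fp16 = torch.finfo(torch.float16)
print("fp16 max:", fp16.max)  # largest finite float16 value
print("fp16 eps:", fp16.eps)  # spacing of float16 values just above 1.0

x = torch.tensor([70000.0, 1.0001, 1e-8], dtype=torch.float32)
y = x.to(torch.float16)

# 70000 exceeds the fp16 range and becomes inf; 1.0001 rounds to 1.0;
# 1e-8 is below the smallest fp16 subnormal and becomes 0.
print(x.tolist())
print(y.tolist())
```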

yiyixuxu commented 1 month ago

@Luciennnnnnn, can you run the same experiment for SD3 to see whether it also produces a worse image in fp16? https://github.com/huggingface/diffusers/issues/8604#issue-2357443101

T5 embeddings are used differently in SD3 and PixArt, so it is possible that this has less or no effect in SD3. But we were not aware before that these layers in T5 need to be in fp32, so it's not impossible that training could work better for SD3 if we did that.

yiyixuxu commented 1 month ago

A quick test for SD3 here: fp16 (bottom row) seems OK?

import torch
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel

repo = "stabilityai/stable-diffusion-3-medium-diffusers"
dtype = torch.float16

pipe = StableDiffusion3Pipeline.from_pretrained(repo, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
# loaded via from_pretrained(torch_dtype=fp16), so the T5 `wo` layers should still be fp32
print(pipe.text_encoder_3.encoder.block[11].layer[1].DenseReluDense.wo.weight.dtype)
out = []
generator = torch.Generator(device="cpu").manual_seed(0)
for i in range(2):
    image = pipe(
        "A cat holding a sign that says hello world",
        negative_prompt="",
        num_inference_steps=28,
        guidance_scale=7.0,
        generator=generator,
    ).images[0]
    out.append(image)
# cast every T5 parameter, including `wo`, to fp16 (bottom row of the grid)
pipe.text_encoder_3 = pipe.text_encoder_3.to(dtype)
print(pipe.text_encoder_3.encoder.block[11].layer[1].DenseReluDense.wo.weight.dtype)
generator = torch.Generator(device="cpu").manual_seed(0)
for i in range(2):
    image = pipe(
        "A cat holding a sign that says hello world",
        negative_prompt="",
        num_inference_steps=28,
        guidance_scale=7.0,
        generator=generator,
    ).images[0]
    out.append(image)

from diffusers.utils import make_image_grid
make_image_grid(out, rows=2, cols=2).save("yiyi_test_1_out.png")

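[Image: yiyi_test_1_out.png, 2x2 grid. Top row: T5 as loaded by from_pretrained; bottom row: T5 cast with .to(torch.float16)]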

sayakpaul commented 3 weeks ago

How do we want to go about this? Should we maybe document this issue to start with? Gently pinging @asomoza and @yiyixuxu about this.

asomoza commented 3 weeks ago

I also did a test with SD3. It has the same problem, but the CLIP text encoders save the generation.

I also did the same test with the original T5: this also happens when I use it to generate embeddings, but when I run inference it works with both methods.

Probably the best solution is to make the embeddings stay the same when we do T5EncoderModel.from_pretrained(...).to(dtype=torch.float16), but that's something that has to be done on the transformers side.

Since that could take more time, I also think we should add a note to the docs that, for the T5 text encoders, users can't do T5EncoderModel.from_pretrained(...).to(dtype=torch.float16).
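Until transformers handles this, one option is to keep passing `torch_dtype=torch.float16` to `from_pretrained`. For a model that is already loaded in float32 (for example, before fine-tuning), a sketch like the one below casts to fp16 while leaving `wo` layers in fp32, imitating what `from_pretrained` appears to do. `cast_keeping_fp32` and its `keep` default are hypothetical, not a transformers or diffusers API, and the toy module only stands in for a T5 encoder.

```python
import torch
from torch import nn


def cast_keeping_fp32(model: nn.Module, dtype=torch.float16, keep=("wo",)) -> nn.Module:
    """Hypothetical helper: cast `model` to `dtype`, but keep every submodule whose
    attribute name is listed in `keep` (T5's feed-forward `wo` by default) in float32."""
    model.to(dtype=dtype)
    for name, module in model.named_modules():
        # named_modules() yields dotted paths such as "DenseReluDense.wo"
        if name.split(".")[-1] in keep:
            module.to(dtype=torch.float32)
    return model


# Toy stand-in for a T5 feed-forward block.
toy = nn.ModuleDict(
    {"DenseReluDense": nn.ModuleDict({"wi": nn.Linear(4, 8), "wo": nn.Linear(8, 4)})}
)
cast_keeping_fp32(toy)
print(toy["DenseReluDense"]["wi"].weight.dtype, toy["DenseReluDense"]["wo"].weight.dtype)
```

With a real encoder this would be `cast_keeping_fp32(T5EncoderModel.from_pretrained(repo, subfolder="text_encoder"))`; matching on the last path component keeps unrelated modules that merely contain the letters "wo" from being left in fp32.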

yiyixuxu commented 3 weeks ago

Yes, I think we should document this :)