huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

PyTorch 2.0 ROCm: torch.compile gives RuntimeWarning and generates black images #2758

Closed ttio2tech closed 1 year ago

ttio2tech commented 1 year ago

I was testing the https://huggingface.co/docs/diffusers/optimization/torch2.0 examples. The plain 'Accelerated Transformers implementation' was successful, but the torch.compile example

pipe.unet = torch.compile(pipe.unet)
image = pipe(prompt).images[0]

gives the warning

pipelines/pipeline_utils.py:1023: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")

resulting in black images.

sayakpaul commented 1 year ago

Does this happen for all prompts you tried so far?

@patrickvonplaten @pcuenca Cc.

Stax124 commented 1 year ago

This bug happens in the UNet step and it differs from sampler to sampler. The UNet takes the input latent in the correct format but returns a tensor of the same shape full of NaNs. From my testing, this happens quite often with DPM samplers, while EulerA is quite stable.
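
A rough way to confirm where the NaNs first appear (not from the original comment; a debugging sketch assuming the standard callback/callback_steps arguments of StableDiffusionPipeline) is to inspect the intermediate latents at every denoising step:

import torch

def check_latents(step, timestep, latents):
    # Report any denoising step whose latents already contain NaNs
    # (hypothetical helper, only for debugging).
    if torch.isnan(latents).any():
        print(f"NaNs in latents at step {step} (timestep {timestep})")

image = pipe(prompt, callback=check_latents, callback_steps=1).images[0]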

sayakpaul commented 1 year ago

Hmm. Could you ensure you're on the latest Torch 2.0 install?

Stax124 commented 1 year ago

I'm a VoltaML developer; we are using the stable release of PyTorch 2.0 with CUDA 11.8 in a containerized environment, so nothing should interfere.

sayakpaul commented 1 year ago

I will let @pcuenca comment here.

patrickvonplaten commented 1 year ago

@ttio2tech ,

could you add a reproducible code snippet here?

ttio2tech commented 1 year ago

The code that caused the error:

Environment: ROCm 5.4.2, PyTorch 2.0

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "hakurei/waifu-diffusion", torch_dtype=torch.float16
).to("cuda")
pipe.unet = torch.compile(pipe.unet)  # this step gave the error

sayakpaul commented 1 year ago

Does it give a warning or an error?

If it's an error, it could be because of dependency problems. Could you post the error snippet?

patrickvonplaten commented 1 year ago

Also seeing this problem a bit now. @pcuenca, @patil-suraj, did you ever encounter this?

pcuenca commented 1 year ago

I'm confused: the issue title and @ttio2tech mention a ROCm environment, but @Stax124 says they are using CUDA. Does this issue only apply to ROCm/AMD cards, or is it a general problem with torch 2?

I haven't seen it myself, but will test the model mentioned by @ttio2tech (on CUDA, unfortunately I have no AMD cards compatible with ROCm).

onitake commented 1 year ago

It's a bit hard to produce a minimal example due to the various dependencies involved, but I think I've been able to narrow the issue down a bit. I'm running diffusers 0.16.1 with a custom build of PyTorch 2.1 against ROCm 5.5.0 (currently needed for AMD RDNA3 support). The GPU is an RX 7900 XTX. FP32 works fine in all cases, but with FP16 I get the NaN issue when using the model https://huggingface.co/waifu-diffusion/wd-1-5-beta2. With https://huggingface.co/hakurei/waifu-diffusion, FP16 is ok as well.

I've attached a dump of the pipelines and a diff here: https://gist.github.com/onitake/1534ea2eedecb5d346fc70ba2a278d81

What stands out to me is that the working pipeline (WD-1.4) uses CLIPImageProcessor while the non-working one (WD-1.5) uses CLIPFeatureExtractor, and that WD-1.4 sets PNDMScheduler.prediction_type='epsilon', while in WD-1.5 it's PNDMScheduler.prediction_type='v_prediction'.

The code to trigger the error is very basic:

import torch
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path="../wd-1-5-beta2/",
    torch_dtype=torch.float16,
).to('cuda')
generator = torch.Generator("cuda").manual_seed(0)
result = pipe(prompt="sunflower", generator=generator)
for image in result.images:
    image.save("test.png")

I've tried a few other configurations as well, but the result is the same.

Is that enough to investigate the issue, or is there anything else I can try/provide?

Edit: I tried modifying the scheduler configuration to use prediction_type='epsilon', but that still resulted in NaNs.

Edit 2: CLIPFeatureExtractor turned out to be a red herring too. Modifying the pipeline description to use CLIPImageProcessor instead didn't help.
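
For reference, one way to attempt such a scheduler override (a sketch, not necessarily the exact change tried above; it assumes the usual from_config override pattern in diffusers) would be:

from diffusers import PNDMScheduler

# Rebuild the scheduler from the pipeline's existing config,
# overriding only prediction_type (assumed override pattern).
pipe.scheduler = PNDMScheduler.from_config(
    pipe.scheduler.config, prediction_type="epsilon"
)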

onitake commented 1 year ago

I traced a failed run through to https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/attention_processor.py#L743

This will produce some NaN values in the result, and as soon as that happens, they'll quickly propagate through the whole hidden_states tensor at the end of the iteration.
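
One generic way to pin down which module first emits NaNs (not part of the original trace; a debugging sketch using standard PyTorch forward hooks) is:

import torch

def add_nan_hooks(model):
    # Register a forward hook on every submodule and print each module
    # whose output contains NaNs; the first one printed is roughly
    # where they originate before propagating.
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and torch.isnan(output).any():
                print(f"NaNs in output of: {name}")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))

add_nan_hooks(pipe.unet)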

Maybe related: https://pytorch.org/docs/main/notes/numerical_accuracy.html (I tried some of those environment variables and flags, but they didn't help).

onitake commented 1 year ago

After experimenting with https://pytorch.org/docs/main/generated/torch.nn.functional.scaled_dot_product_attention.html a bit, I found that my Torch build only includes the C++ implementation. Maybe this is the problem, or maybe the optimized algorithms are hiding the NaNs.

Unfortunately, there doesn't seem to be a ROCm implementation of the FlashAttention algorithm yet. AMD is still working on it: https://github.com/ROCmSoftwarePlatform/flash-attention/pull/1
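
To check whether a particular SDPA backend is to blame, one option (a sketch assuming the PyTorch 2.x torch.backends.cuda.sdp_kernel context manager, which ROCm builds expose through the CUDA device API) is to force the plain math fallback and rerun the pipeline:

import torch

# Disable the fused/optimized attention kernels; if the NaNs disappear
# with the math fallback, an optimized kernel is the likely culprit.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_mem_efficient=False, enable_math=True
):
    image = pipe(prompt).images[0]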

onitake commented 1 year ago

In any case, I found that I can hide the issue by adding the following line after the F.scaled_dot_product_attention() call:

hidden_states = torch.nan_to_num(hidden_states)

It didn't look like it had much of a performance impact on my machine. The algorithm ran at about the same speed as WD-1.4 without this line.
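
For anyone who prefers not to patch diffusers in place, a rough equivalent (an illustrative monkey-patch sketch, not the change described above) is to wrap F.scaled_dot_product_attention globally:

import torch
import torch.nn.functional as F

_original_sdpa = F.scaled_dot_product_attention

def _sdpa_with_nan_guard(*args, **kwargs):
    # Replace any NaNs produced by the attention kernel before they can
    # propagate through the rest of the UNet (workaround, not a fix).
    return torch.nan_to_num(_original_sdpa(*args, **kwargs))

F.scaled_dot_product_attention = _sdpa_with_nan_guard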

sayakpaul commented 1 year ago

Thanks for sharing your detailed insights here. If this issue comes up multiple times, I think we might fix it with the following (referring users to this issue thread):

hidden_states = torch.nan_to_num(hidden_states)

But if this behavior could be reproduced minimally with just vanilla PyTorch code on your setup, that'd probably be more helpful for the PyTorch team to understand the lower-level reasons behind this.
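
Such a standalone reproduction might look roughly like the following (a hypothetical sketch with arbitrary tensor shapes, not a confirmed repro):

import torch
import torch.nn.functional as F

# Call scaled_dot_product_attention directly in fp16 on the GPU and
# check whether the output contains NaNs, independently of diffusers.
q = torch.randn(2, 8, 4096, 40, dtype=torch.float16, device="cuda")
k = torch.randn(2, 8, 4096, 40, dtype=torch.float16, device="cuda")
v = torch.randn(2, 8, 4096, 40, dtype=torch.float16, device="cuda")
out = F.scaled_dot_product_attention(q, k, v)
print("NaNs in output:", torch.isnan(out).any().item())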

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

vedmant commented 4 months ago

I have the same issue on M1 with the https://huggingface.co/stabilityai/stable-diffusion-2-1 model.