huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

excessive graph breaks on `attention.py` and `attention_processor.py` for control_net on `torch.compile` #3218

Closed: shingjan closed this issue 1 year ago

shingjan commented 1 year ago

Describe the bug

I tried to run the ControlNet example from this blog post, and it turned out that BasicTransformerBlock causes a large number of graph breaks (>100) in a single ControlNet pipeline. Ideally, the whole of BasicTransformerBlock.forward would be captured in a single graph to get the speedup. The exact reason for the graph breaks is:

call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': TensorVariable(), 'attention_mask': ConstantVariable(NoneType)}

This happens for both self-attention and cross-attention. Is there a way to reduce the graph breaks so that StableDiffusionControlNetPipeline works better with torch.compile?

Reproduction

```python
import cv2
import numpy as np
import torch
import torch._dynamo as dynamo
from PIL import Image

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image

# Load the input image and turn it into a 3-channel Canny edge map for the ControlNet.
image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
image = np.array(image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()


# Optimize the entire pipeline call (not just the UNet) with TorchInductor.
@dynamo.optimize("inductor")
def generate(prompt):
    # One seeded generator per prompt for reproducible outputs.
    generator = [torch.Generator(device="cuda").manual_seed(2) for _ in range(len(prompt))]
    return pipe(
        prompt,
        canny_image,
        negative_prompt=["monochrome, lowres, bad anatomy, worst quality, low quality"] * len(prompt),
        num_inference_steps=10,
        generator=generator,
    )


prompt = ", best quality, extremely detailed"
prompt = [t + prompt for t in ["Sandra Oh", "Kim Kardashian", "rihanna", "taylor swift"]]
# Print Dynamo's explanation of the graphs and graph breaks produced while tracing generate().
ex = dynamo.explain(generate, prompt)[-1]
print(ex)
```
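The calling convention of `torch._dynamo.explain` has changed across PyTorch releases. A version-independent way to count graphs is a custom backend that increments a counter each time Dynamo hands it a graph and then runs that graph unchanged. This is a minimal sketch: `counting_backend` is a hypothetical name, and `torch._dynamo.graph_break()` is used only to force a break in a toy function.

```python
import torch
import torch._dynamo as dynamo

graph_count = 0


def counting_backend(gm, example_inputs):
    """Hypothetical Dynamo backend: count each FX GraphModule received, then run it as-is."""
    global graph_count
    graph_count += 1
    return gm.forward


def fn(x):
    y = x + 1
    dynamo.graph_break()  # deliberately split the function into two graphs
    return y * 2


dynamo.reset()  # clear previously compiled code so the count starts fresh
compiled_fn = torch.compile(fn, backend=counting_backend)
out = compiled_fn(torch.ones(3))
print(f"graphs compiled: {graph_count}")
```

Each graph beyond the first points to a graph break or a recompilation, so passing such a backend to `torch.compile` gives a rough break count without depending on the `explain` API.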

Logs

```shell
graph #169 break reason: call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': ConstantVariable(NoneType), 'attention_mask': ConstantVariable(NoneType)} after 3
stack:   File "/home/yj/diffusers/src/diffusers/models/attention.py", line 313, in forward
    attn_output = self.attn1(
  File "/home/yj/pytorch/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yj/diffusers/src/diffusers/models/attention_processor.py", line 267, in forward
    return self.processor(

graph #171 break reason: call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': TensorVariable(), 'attention_mask': ConstantVariable(NoneType)} after 1
stack:   File "/home/yj/diffusers/src/diffusers/models/attention.py", line 331, in <resume in forward>
    attn_output = self.attn2(
  File "/home/yj/pytorch/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yj/diffusers/src/diffusers/models/attention_processor.py", line 267, in forward
    return self.processor(
```

System Info

Ubuntu 20.04 with CUDA 11.8

diffusers 0.16.0.dev0 (/home/yj/diffusers)
torch 2.1.0a0+git0bbf8a9 (/home/yj/pytorch)

sayakpaul commented 1 year ago

Cc: @pcuenca

patrickvonplaten commented 1 year ago

@shingjan, we advise optimizing only the UNet part of the pipeline with TorchInductor. Could you instead try:

```python
pipe.unet = torch.compile(pipe.unet, backend='inductor')
```
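This suggestion keeps the pipeline's Python orchestration (the denoising loop, generator handling, pre- and post-processing) in eager mode and compiles only the module that runs at every step. Below is a toy version of that pattern. It assumes a hypothetical `ToyPipeline` whose `unet` attribute is a plain `nn.Module`, and it uses `backend="eager"` only so the sketch runs without a C++ toolchain.

```python
import torch
from torch import nn


class ToyPipeline:
    """Hypothetical pipeline: Python-heavy orchestration around one heavy nn.Module."""

    def __init__(self):
        self.unet = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8)).eval()

    @torch.no_grad()
    def __call__(self, latents, num_inference_steps=10):
        # The loop and step bookkeeping stay in eager Python;
        # only the per-step call into self.unet goes through the compiled module.
        for _ in range(num_inference_steps):
            latents = latents - 0.1 * self.unet(latents)
        return latents


pipe = ToyPipeline()
latents = torch.randn(4, 8)
reference = pipe(latents)

# Same shape of change as pipe.unet = torch.compile(pipe.unet, backend='inductor').
pipe.unet = torch.compile(pipe.unet, backend="eager")
compiled_result = pipe(latents)
```

Compiling only the submodule also means graph breaks in the surrounding pipeline code no longer matter. Breaks inside the UNet itself, such as the processor calls discussed in this issue, still would.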

Also see: https://huggingface.co/docs/diffusers/optimization/torch2.0#using-accelerated-transformers-and-torchcompile

shingjan commented 1 year ago

@patrickvonplaten thanks for the response! I think AttnProcessor2_0 is used heavily in the UNet, so the graph breaks persist even when only the UNet is wrapped with torch.compile.

patrickvonplaten commented 1 year ago

I don't fully understand this. What exactly is the issue here? Can we reproduce it somehow?

patrickvonplaten commented 1 year ago

@shingjan,

This might help solve it actually: https://github.com/huggingface/diffusers/pull/3286

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

shingjan commented 1 year ago

@patrickvonplaten Sorry for the late reply. Yes, I rebased, and most of the graph breaks seen with diffusers==0.16.1 are gone. The `maybe_allow_in_graph` and `is_compiled` changes are very useful. Closing this one as fixed by #3286.
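The exact definition of `is_compiled` in diffusers is not shown in this thread. To illustrate why such a check is handy, the sketch assumes only that `torch.compile` returns a wrapper module that keeps the original under `_orig_mod`. Attribute access and `state_dict()` keys on the wrapper therefore differ from the uncompiled model. `is_compiled_wrapper` and `unwrap` are hypothetical helpers, not diffusers functions.

```python
import torch
from torch import nn


def is_compiled_wrapper(module):
    """Hypothetical check: torch.compile keeps the wrapped module in `_orig_mod`."""
    return hasattr(module, "_orig_mod")


def unwrap(module):
    """Return the original module whether or not it was wrapped by torch.compile."""
    return module._orig_mod if is_compiled_wrapper(module) else module


model = nn.Linear(4, 4)
compiled = torch.compile(model)  # lazy: nothing is compiled until the first forward call

# The wrapper registers the original as a submodule, so its state_dict keys are prefixed.
print(list(compiled.state_dict()))
```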