huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0
26.21k stars, 5.4k forks

`torch.compile` doesn't seem to be working for text-to-video pipelines #3915

Closed (apolinario closed this issue 1 year ago)

apolinario commented 1 year ago

Describe the bug

Using torch.compile on a text-to-video model doesn't work.

If I follow the docs and call pipe.unet.to(memory_format=torch.channels_last):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
pipe.unet.to(memory_format=torch.channels_last)

I get:

RuntimeError: required rank 4 tensor to use channels_last format 

If I skip the channels_last conversion and compile directly:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

I get:

RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'
Keyboard interruption in main thread... closing server. 

Reproduction

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

Logs

Full traceback for pipe.unet.to(memory_format=torch.channels_last)

Traceback (most recent call last):
  in <cell line: 1>:1

  /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1145 in to

    1142                               non_blocking, memory_format=convert_to_format)
    1143               return t.to(device, dtype if t.is_floating_point() or t.is_complex() else No
    1144
  ❱ 1145           return self._apply(convert)
    1146
    1147       def register_full_backward_pre_hook(
    1148           self,

  /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797 in _apply

     794
     795       def _apply(self, fn):
     796           for module in self.children():
  ❱  797               module._apply(fn)
     798
     799           def compute_should_use_set_data(tensor, tensor_applied):
     800               if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):

  [the module.py:797 in _apply frame above appears 7 times in a row as _apply recurses into child modules]

  /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:820 in _apply

     817               # track autograd history of `param_applied`, so we have to use
     818               # `with torch.no_grad():`
     819               with torch.no_grad():
  ❱  820                   param_applied = fn(param)
     821               should_use_set_data = compute_should_use_set_data(param, param_applied)
     822               if should_use_set_data:
     823                   param.data = param_applied

  /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1141 in convert

    1138
    1139           def convert(t):
    1140               if convert_to_format is not None and t.dim() in (4, 5):
  ❱ 1141                   return t.to(device, dtype if t.is_floating_point() or t.is_complex() els
    1142                               non_blocking, memory_format=convert_to_format)
    1143               return t.to(device, dtype if t.is_floating_point() or t.is_complex() else No
    1144
RuntimeError: required rank 4 tensor to use channels_last format 
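
The convert frame above shows why this fails: Module.to(memory_format=...) passes the requested format to every 4-D and 5-D parameter, but torch.channels_last only accepts rank-4 tensors, so any 5-D weight (for example from a 3-D convolution) raises this error; 5-D tensors use torch.channels_last_3d instead. The sketch below reproduces the error with two stand-in layers rather than the real UNet. to_channels_last_by_rank is a hypothetical helper, and whether it actually speeds up this pipeline is untested.

```python
import torch
import torch.nn as nn


def try_memory_format(module: nn.Module, fmt: torch.memory_format) -> str:
    """Return "ok", or the RuntimeError message raised by module.to(memory_format=fmt)."""
    try:
        module.to(memory_format=fmt)
        return "ok"
    except RuntimeError as err:
        return str(err)


def to_channels_last_by_rank(module: nn.Module) -> nn.Module:
    """Hypothetical helper: channels_last for 4-D params, channels_last_3d for 5-D params."""
    with torch.no_grad():
        for param in module.parameters():
            if param.dim() == 4:
                param.data = param.data.contiguous(memory_format=torch.channels_last)
            elif param.dim() == 5:
                param.data = param.data.contiguous(memory_format=torch.channels_last_3d)
    return module


def make_model() -> nn.ModuleDict:
    # Stand-ins for the 2-D (spatial) and 3-D (temporal) convolutions of a video UNet.
    return nn.ModuleDict({"spatial": nn.Conv2d(4, 4, 3), "temporal": nn.Conv3d(4, 4, 3)})


print(try_memory_format(nn.Conv2d(4, 4, 3), torch.channels_last))  # ok
print(try_memory_format(make_model(), torch.channels_last))  # the rank-4 error from the issue

model = to_channels_last_by_rank(make_model())
print(model["temporal"].weight.is_contiguous(memory_format=torch.channels_last_3d))  # True
```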

Full traceback for pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) (raised when the pipeline is called)

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/gradio/routes.py", line 437, in run_predict
    output = await app.get_blocks().process_api(
  File "/usr/local/lib/python3.10/dist-packages/gradio/blocks.py", line 1352, in process_api
    result = await self.call_function(
  File "/usr/local/lib/python3.10/dist-packages/gradio/blocks.py", line 1077, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
  File "/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py", line 807, in run
    result = context.run(func, *args)
  File "<ipython-input-13-947ecc021452>", line 13, in infer
    video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=24).frames
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py", line 605, in __call__
    prompt_embeds = self._encode_prompt(
  File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py", line 298, in _encode_prompt
    prompt_embeds = self.text_encoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/clip/modeling_clip.py", line 822, in forward
    return self.text_model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/clip/modeling_clip.py", line 740, in forward
    encoder_outputs = self.encoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/clip/modeling_clip.py", line 654, in forward
    layer_outputs = encoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/clip/modeling_clip.py", line 382, in forward
    hidden_states = self.layer_norm1(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/normalization.py", line 190, in forward
    return F.layer_norm(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 2515, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'
Keyboard interruption in main thread... closing server. 

System Info

diffusers==0.17.1

Who can help?

@patrickvonplaten

patrickvonplaten commented 1 year ago

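I think the above error is caused by a missing .to("cuda") call: without it the float16 model stays on the CPU, where LayerNorm has no 'Half' kernel. Note that mode="reduce-overhead" is meant for CUDA, since it relies on CUDA graphs.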
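
A minimal sketch of the device/dtype handling implied by this diagnosis, assuming the goal is simply to avoid running float16 weights on the CPU: pick float16 only when a CUDA device is available, and float32 otherwise. pick_device_and_dtype is a hypothetical helper, and the nn.LayerNorm merely stands in for the CLIP text encoder's layer_norm1 from the traceback.

```python
import torch
import torch.nn as nn


def pick_device_and_dtype() -> tuple[torch.device, torch.dtype]:
    """Hypothetical helper: float16 only on CUDA, float32 on the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16
    return torch.device("cpu"), torch.float32


device, dtype = pick_device_and_dtype()

# Stand-in for the text encoder's layer_norm1 from the traceback.
norm = nn.LayerNorm(8).to(device=device, dtype=dtype)
hidden_states = torch.randn(2, 8, device=device, dtype=dtype)
out = norm(hidden_states)
print(out.device.type, out.dtype)
```

In the pipeline, the same values would go into DiffusionPipeline.from_pretrained(..., torch_dtype=dtype) followed by pipe.to(device).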

But if I add a .to("cuda") statement I get a new error:

    getattr(self, inst.opname)(inst)
  File "/home/patrick/hf/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "/home/patrick/hf/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1014, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/patrick/hf/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 474, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/patrick/hf/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 744, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)
  File "/home/patrick/hf/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 424, in call_method
    return wrap_fx_proxy(
  File "/home/patrick/hf/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 754, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/home/patrick/hf/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 789, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/home/patrick/hf/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1168, in get_fake_value
    unimplemented(f"dynamic shape operator: {cause.func}")
  File "/home/patrick/hf/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 71, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: dynamic shape operator: aten.repeat_interleave.Tensor

from user code:
   File "/home/patrick/python_bin/diffusers/models/unet_3d_condition.py", line 521, in forward
    emb = emb.repeat_interleave(repeats=num_frames, dim=0)

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

The error can be reproduced by running:

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video
from PIL import Image

pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
pipe.to("cuda")
pipe.enable_vae_slicing()

pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

prompt = "Darth Vader is surfing on waves"
video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=36).frames
video_path = export_to_video(video_frames, output_video_path="/home/patrick/videos/video_576_darth_vader_36.mp4")
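
The reproduction above passes fullgraph=True, which turns every graph break (such as the unsupported repeat_interleave call) into a hard error. As a rough sketch, compiling with fullgraph=False instead lets Dynamo split the graph and run the unsupported part eagerly, trading some speed for robustness. The toy function is a stand-in for the UNet's forward (not the real code), and backend="eager" is used only so the sketch runs without a GPU or C++ toolchain.

```python
import torch


def add_frame_embedding(sample: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for the UNet forward: sample is (batch * frames, dim), emb is (batch, dim).
    num_frames = sample.shape[0] // emb.shape[0]
    emb = emb.repeat_interleave(repeats=num_frames, dim=0)
    return sample + emb


# fullgraph=False allows graph breaks, so ops Dynamo cannot capture fall back to eager.
# backend="eager" runs the captured graph without code generation; on CUDA you would
# normally keep the default inductor backend for actual speedups.
compiled = torch.compile(add_frame_embedding, backend="eager", fullgraph=False)

emb = torch.randn(2, 8)
sample = torch.randn(2 * 3, 8)
print(torch.allclose(compiled(sample, emb), add_frame_embedding(sample, emb)))  # True
```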

I'm currently a bit busy with other things. @sayakpaul, do you have some time to look into it by any chance?

sayakpaul commented 1 year ago

There seems to be a known problem with repeat_interleave() that might already be fixed in the nightlies. Trying that out now.
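
As an aside on the repeat_interleave() problem: for a 2-D embedding and an integer repeat count, the same result can be built from unsqueeze, expand, and reshape, whose output shapes follow directly from the input shapes. This is a swapped-in alternative shown as a sketch; it is not necessarily the fix that was eventually adopted, and repeat_frames_static is a hypothetical name.

```python
import torch


def repeat_frames_static(emb: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Hypothetical helper: same result as emb.repeat_interleave(num_frames, dim=0) for 2-D emb."""
    batch, dim = emb.shape
    # (batch, dim) -> (batch, 1, dim) -> (batch, num_frames, dim) -> (batch * num_frames, dim)
    return emb.unsqueeze(1).expand(batch, num_frames, dim).reshape(batch * num_frames, dim)


emb = torch.arange(6.0).reshape(2, 3)
reference = emb.repeat_interleave(repeats=4, dim=0)
alternative = repeat_frames_static(emb, 4)
print(torch.equal(reference, alternative))  # True
print(alternative.shape)  # torch.Size([8, 3])
```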

sayakpaul commented 1 year ago

@patrickvonplaten let's jam here: https://github.com/huggingface/diffusers/pull/3949.

hnnam0906 commented 6 months ago

Hi! I'm trying to use torch.compile with the model "damo-vilab/text-to-video-ms-1.7b" (https://huggingface.co/docs/diffusers/api/pipelines/text_to_video), but generation takes a very long time. On a 3090, generating a video with num_inference_steps=25 takes about 10 s without torch.compile but more than 100 s with it. What could be causing this, and could you help fix it? Thanks!

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)  # the torch.compile line in question
pipe.enable_model_cpu_offload()

pipe.enable_vae_slicing()

prompt = "Darth Vader surfing a wave"
video_frames = pipe(prompt, num_inference_steps=25).frames
video_path = export_to_video(video_frames)
video_path

tolgacangoz commented 6 months ago

Hi @hnnam0906! The network is compiled during the first inference, so don't count that call; time the subsequent inferences instead. See here.
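
This advice can be made concrete with a small timing sketch: measure the first call separately (for a compiled model it includes compilation) and report the mean of later calls. benchmark is a hypothetical helper, demonstrated on a toy matrix multiply rather than the pipeline; the torch.cuda.synchronize() calls are there because CUDA kernels run asynchronously.

```python
import time
from statistics import mean
from typing import Callable

import torch


def benchmark(fn: Callable[[], object], warmup: int = 1, iters: int = 3) -> tuple[float, float]:
    """Hypothetical helper: return (first-call seconds, mean seconds per call after warm-up).

    Assumes iters >= 1. With torch.compile, the first call is where compilation happens.
    """

    def timed_call() -> float:
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # finish pending GPU work before starting the clock
        start = time.perf_counter()
        fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for the GPU work launched by fn
        return time.perf_counter() - start

    warmup_times = [timed_call() for _ in range(max(warmup, 1))]
    steady_times = [timed_call() for _ in range(iters)]
    return warmup_times[0], mean(steady_times)


x = torch.randn(256, 256)
first, steady = benchmark(lambda: x @ x, warmup=2, iters=5)
print(f"first call: {first:.4f} s, steady state: {steady:.4f} s")
```

With the pipeline from the question, the call would look something like benchmark(lambda: pipe(prompt, num_inference_steps=25), warmup=1, iters=3); only the second number reflects steady-state speed.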

hnnam0906 commented 6 months ago

@standardAI: Thanks for the info.