huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Torch Compile support for runwayml/stable-diffusion-v1-5 VAE Decoder #9091

Closed: asfiyab-nvidia closed this issue 1 month ago

asfiyab-nvidia commented 2 months ago

### Describe the bug

I'd like to run the runwayml/stable-diffusion-v1-5 pipeline using Torch Compile in `reduce-overhead` mode. However, the compiled VAE decoder fails with `NameError: name 'torch' is not defined`. The script below reproduces the error shown in the logs.

Initial debugging on my end showed that the issue is triggered by the accelerate package: the snippet runs without errors once accelerate is uninstalled.

### Reproduction

import torch
from diffusers.models import AutoencoderKL

model_id = "runwayml/stable-diffusion-v1-5"
model = AutoencoderKL.from_pretrained(model_id, subfolder='vae', use_safetensors=True).to("cuda")
model.forward = model.decode
model = torch.compile(model, mode="reduce-overhead", dynamic=False, fullgraph=False)

latents = torch.randn(torch.Size([1, 4, 64, 64]), device="cuda", dtype = torch.float32)
image = model(latents)

### Logs

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /workspace/demo/Diffusion/test_txt2img_compile.py:10 in <module>                                 │
│                                                                                                  │
│    7 model = torch.compile(model, mode="reduce-overhead", dynamic=False, fullgraph=False)        │
│    8                                                                                             │
│    9 latents = torch.randn(torch.Size([1, 4, 64, 64]), device="cuda", dtype = torch.float32)     │
│ ❱ 10 image = model(latents)                                                                      │
│   11                                                                                             │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532 in _wrapped_call_impl    │
│                                                                                                  │
│   1529 │   │   if self._compiled_call_impl is not None:                                          │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1531 │   │   else:                                                                             │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1533 │                                                                                         │
│   1534 │   def _call_impl(self, *args, **kwargs):                                                │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541 in _call_impl            │
│                                                                                                  │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1542 │   │                                                                                     │
│   1543 │   │   try:                                                                              │
│   1544 │   │   │   result = None                                                                 │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:410 in _fn                   │
│                                                                                                  │
│    407 │   │   │   cleanups = [enter() for enter in self.enter_exit_hooks]                       │
│    408 │   │   │   prior = set_eval_frame(callback)                                              │
│    409 │   │   │   try:                                                                          │
│ ❱  410 │   │   │   │   return fn(*args, **kwargs)                                                │
│    411 │   │   │   finally:                                                                      │
│    412 │   │   │   │   set_eval_frame(prior)                                                     │
│    413 │   │   │   │   for cleanup in cleanups:                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532 in _wrapped_call_impl    │
│                                                                                                  │
│   1529 │   │   if self._compiled_call_impl is not None:                                          │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1531 │   │   else:                                                                             │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1533 │                                                                                         │
│   1534 │   def _call_impl(self, *args, **kwargs):                                                │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541 in _call_impl            │
│                                                                                                  │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1542 │   │                                                                                     │
│   1543 │   │   try:                                                                              │
│   1544 │   │   │   result = None                                                                 │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/utils/accelerate_utils.py:44 in wrapper        │
│                                                                                                  │
│   41 │   │   return method                                                                       │
│   42 │                                                                                           │
│   43 │   def wrapper(self, *args, **kwargs):                                                     │
│ ❱ 44 │   │   if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):             │
│   45 │   │   │   self._hf_hook.pre_forward(self)                                                 │
│   46 │   │   return method(self, *args, **kwargs)                                                │
│   47                                                                                             │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/utils/accelerate_utils.py:44 in                │
│ torch_dynamo_resume_in_wrapper_at_44                                                             │
│                                                                                                  │
│   41 │   │   return method                                                                       │
│   42 │                                                                                           │
│   43 │   def wrapper(self, *args, **kwargs):                                                     │
│ ❱ 44 │   │   if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):             │
│   45 │   │   │   self._hf_hook.pre_forward(self)                                                 │
│   46 │   │   return method(self, *args, **kwargs)                                                │
│   47                                                                                             │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NameError: name 'torch' is not defined
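The last two frames are inside the `wrapper` that diffusers' `apply_forward_hook` decorator (`diffusers/utils/accelerate_utils.py`) places around `decode`. Reconstructed only from the lines visible in the traceback, that pattern can be sketched in plain Python. `apply_forward_hook_sketch`, `FakeHook`, and `FakeVAE` are hypothetical stand-ins for the decorator, an accelerate offload hook, and the VAE; the `accelerate_available` flag is an assumption modelling the decorator returning the method unwrapped when accelerate is missing, which matches the observation that the error disappears without accelerate. This illustrates the code path only, not the root cause of the `NameError`.

```python
import functools


def apply_forward_hook_sketch(method, accelerate_available=True):
    """Hypothetical re-creation of the decorator pattern seen in the traceback."""
    # Assumption: without accelerate, the method is returned as-is (no wrapper).
    if not accelerate_available:
        return method

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # Call the hook's pre_forward only if a hook exposing it is attached.
        if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):
            self._hf_hook.pre_forward(self)
        return method(self, *args, **kwargs)

    return wrapper


class FakeHook:
    """Hypothetical stand-in for an accelerate offload hook; counts calls."""

    def __init__(self):
        self.calls = 0

    def pre_forward(self, module):
        self.calls += 1


class FakeVAE:
    """Hypothetical stand-in for AutoencoderKL with an optional _hf_hook."""

    def __init__(self, hook=None):
        if hook is not None:
            self._hf_hook = hook

    @apply_forward_hook_sketch
    def decode(self, z):
        return ("decoded", z)


hook = FakeHook()
vae = FakeVAE(hook)
print(vae.decode(3))  # ('decoded', 3)
print(hook.calls)     # 1
```

Every call to the decorated `decode` passes through the `hasattr` checks first, which is why the failing frames point at that line rather than at the decoder itself.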


### System Info

Using the NVIDIA NGC PyTorch 24.05 container, which can be launched with `docker run --rm -it --gpus all -v $PWD:/workspace nvcr.io/nvidia/pytorch:24.05-py3 /bin/bash`

Additional packages installed:
accelerate==0.20.3
diffusers==0.26.3
onnx==1.15.0
onnxruntime==1.16.3
transformers==4.33.1
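Because the error only shows up when accelerate is installed, exact package versions matter for reproducing it. A small standard-library sketch that reports them via `importlib.metadata`; the `PACKAGES` list is copied from this report and `package_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages relevant to this report; adjust as needed.
PACKAGES = ["torch", "accelerate", "diffusers", "onnx", "onnxruntime", "transformers"]


def package_versions(names):
    """Return {name: installed version string, or None if not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in package_versions(PACKAGES).items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```

Reporting "not installed" explicitly makes it obvious whether accelerate was present during a given run.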

### Who can help?

@sayakpaul @DN6 @yiyixuxu

asfiyab-nvidia commented 2 months ago

@sayakpaul @DN6 @yiyixuxu can you please help investigate?

sayakpaul commented 2 months ago

Cc: @muellerzr

asfiyab-nvidia commented 2 months ago

@muellerzr, is this a known issue? If another issue is tracking it, please link it.

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

sayakpaul commented 1 month ago

Can you try updating to a nightly build of PyTorch and running it again?

asfiyab-nvidia commented 1 month ago

Thanks! It works with the latest version of PyTorch. Closing.

asfiyab-nvidia commented 1 month ago

The accelerate package wasn't installed when I closed the issue. Can we please reopen it? I'm seeing the same error with the following versions of the relevant packages:

torch 2.6.0.dev20240917+cu124
diffusers 0.30.3
accelerate 0.34.2

sayakpaul commented 1 month ago

I was able to get rid of it with this simple patch:

diff --git a/src/diffusers/utils/accelerate_utils.py b/src/diffusers/utils/accelerate_utils.py
index 99a8b3a47..cc14070d2 100644
--- a/src/diffusers/utils/accelerate_utils.py
+++ b/src/diffusers/utils/accelerate_utils.py
@@ -17,12 +17,15 @@ Accelerate utilities: Utilities related to accelerate

 from packaging import version

-from .import_utils import is_accelerate_available
+from .import_utils import is_accelerate_available, is_torch_available

 if is_accelerate_available():
     import accelerate

+if is_torch_available():
+    import torch
+

 def apply_forward_hook(method):
     """
@@ -33,7 +36,8 @@ def apply_forward_hook(method):
     This decorator looks inside the internal `_hf_hook` property to find a registered offload hook.

     :param method: The method to decorate. This method should be a method of a PyTorch module.
-    """
+    """ 
+
     if not is_accelerate_available():
         return method
     accelerate_version = version.parse(accelerate.__version__).base_version
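The patch binds `torch` as a module-level global in `accelerate_utils.py`, guarded by `is_torch_available()`, the same way `accelerate` is already guarded. A generic standard-library sketch of that optional-import guard; `is_available` and `optional_import` are hypothetical helpers, not diffusers functions. It shows the pattern only, not why the `NameError` is raised in the resumed frame `torch_dynamo_resume_in_wrapper_at_44`, which the thread leaves unexplained.

```python
import importlib
import importlib.util


def is_available(module_name):
    """Hypothetical stand-in for is_torch_available()/is_accelerate_available()."""
    return importlib.util.find_spec(module_name) is not None


def optional_import(module_name):
    """Import a top-level module if it is installed, else return None."""
    if is_available(module_name):
        return importlib.import_module(module_name)
    return None


# Module-level guards, analogous to the patch: each name is bound once at import
# time so functions in this module can refer to it as a global (None if missing).
torch = optional_import("torch")
accelerate = optional_import("accelerate")
```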

@yiyixuxu, if that's okay with you, I can open a PR. The error can be reproduced with:

import torch
from diffusers.models import AutoencoderKL

model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
model = AutoencoderKL.from_pretrained(model_id, subfolder="vae", use_safetensors=True).to("cuda")
model.decode = torch.compile(model.decode, mode="reduce-overhead", dynamic=False, fullgraph=False)

latents = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float32)
print(latents.shape)
image = model.decode(latents)

This only happens (as far as we know) when a user compiles the `decode()` function.

yiyixuxu commented 1 month ago

So this would work; I had no idea. It would be very interesting to know why the `torch` not-found error occurs, but unfortunately I don't have time to investigate at the moment, since nothing is actually broken.

import torch
from diffusers.models import AutoencoderKL

model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
model = AutoencoderKL.from_pretrained(model_id, subfolder="vae", use_safetensors=True).to("cuda")
model.decode = torch.compile(model.decode, mode="reduce-overhead", dynamic=False, fullgraph=False)

latents = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float32)
print(latents.shape)
# image = model.decode(latents)  # original call: raises NameError
image = model.decode(latents, return_dict=False)[0]  # workaround
print(image.shape)

asfiyab-nvidia commented 1 month ago

Thanks @yiyixuxu. Applying the workaround (WAR) you suggested works for my use case!
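The workaround changes only the return type of `decode()`. Diffusers models follow a `return_dict` convention: by default `decode()` returns an output dataclass (for the VAE, a `DecoderOutput` whose `sample` field holds the image), while `return_dict=False` returns a plain tuple, hence the trailing `[0]`. A pure-Python sketch of that convention; `DecoderOutputSketch` and `decode_sketch` are hypothetical stand-ins and the "decoder" is a placeholder. It does not explain why returning a tuple sidesteps the `NameError`, which the thread leaves open.

```python
from dataclasses import dataclass
from typing import Any, Tuple, Union


@dataclass
class DecoderOutputSketch:
    """Hypothetical stand-in for diffusers' DecoderOutput (field: sample)."""
    sample: Any


def decode_sketch(latents, return_dict=True) -> Union[DecoderOutputSketch, Tuple[Any]]:
    """Mimic the return_dict convention: dataclass by default, tuple otherwise."""
    decoded = [x * 2 for x in latents]  # placeholder for the real VAE decoder
    if not return_dict:
        return (decoded,)
    return DecoderOutputSketch(sample=decoded)


# Default call: read the image from the .sample attribute.
image_a = decode_sketch([1, 2, 3]).sample
# Workaround from the thread: request a tuple and take element 0.
image_b = decode_sketch([1, 2, 3], return_dict=False)[0]
print(image_a == image_b)  # True
```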