huggingface / diffusers

πŸ€— Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0
25.18k stars 5.21k forks source link

torch.compile + cudagraph failed with StableDiffusionXL #8879

Open Danielmic opened 2 months ago

Danielmic commented 2 months ago

Describe the bug

An error occurs when I run StatbleDiffusionXl using torch.compile (pipe.vae.decode, mode="reduce-overhead", fullgraph=True) The necessary condition for trigger this error is torch.compile(pipe.vae.decode) + mode="reduce-overhead" + dtype=torch.float16 I suspect 'upcast_vae' caused it : https://github.com/huggingface/diffusers/blob/e15a8e7f17529ebf4029eed07f982e14c7560546/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1257

Reproduction

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# pipe.unet = torch.compile(pipe.unet, backend="inductor", mode="reduce-overhead", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, backend="inductor", mode="reduce-overhead", fullgraph=True)
# pipe.text_encoder = torch.compile(pipe.text_encoder, fullgraph=True)
# pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, fullgraph=True)

n_steps = 4
# shapes = [[512, 512], [768, 768], [1024,1024]]
shapes = [[512, 512],]
# batch_num = [1, 2 ,4 ,8]
batch_num = [1,]
prompt = "A majestic lion jumping from a big stone at night"

for i in range(len(shapes)):
    for j in range(len(batch_num)):
        torch.compiler.cudagraph_mark_step_begin()
        shape = shapes[i]

        #warmup
        for warmup in range(2):
            _ = pipe(
                prompt=prompt,
                width=shape[0],
                height=shape[1],
                num_inference_steps=n_steps,
                denoising_end=1,
                generator=torch.manual_seed(0),
                num_images_per_prompt=batch_num[j]
                ).images

        images = pipe(
            prompt=prompt,
            width=shape[0],
            height=shape[1],
            num_inference_steps=n_steps,
            denoising_end=1,
            generator=torch.manual_seed(0),
            num_images_per_prompt=batch_num[j]
            ).images

Logs

Loading pipeline components...: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:58<00:00,  8.30s/it]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 4/4 [00:00<00:00,  5.38it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 4/4 [00:00<00:00, 13.59it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 4/4 [00:00<00:00, 14.54it/s]
Traceback (most recent call last):
  File "/workspace/test_cuda/test.py", line 40, in <module>
    images = pipe(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 1279, in __call__
    image = self.vae.decode(latents, return_dict=False)[0]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/accelerate_utils.py", line 43, in wrapper
    def wrapper(self, *args, **kwargs):
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py", line 917, in forward
    return compiled_fn(full_args)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 106, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 152, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codecache.py", line 906, in __call__
    return self.get_current_callable()(inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py", line 838, in run
    return compiled_fn(new_inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 370, in deferred_cudagraphify
    return fn(inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py", line 784, in run
    return model(new_inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 1757, in run
    out = self._run(new_inputs, function_id)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 1810, in _run
    return self.execute_node(child, new_inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 1886, in execute_node
    return node.run(new_inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/cudagraph_trees.py", line 951, in run
    assert data_ptr == new_inputs[idx].data_ptr()
AssertionError

System Info

python 3.10.12 diffusers 0.30.0 torch 2.3.1 transformers 4.42.4 triton 2.3.1

Who can help?

cc @yiyixuxu @sayakpaul @DN6

tolgacangoz commented 2 months ago

I tried with the current PyTorch Nightly (2.5.0.dev20240716+cu121) in Colab, and it seems to work.

Danielmic commented 2 months ago

I tried with torch 2.5.0.dev20240718+cu121, then get a new error @tolgacangoz

Loading pipeline components...: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:11<00:00,  1.69s/it]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 4/4 [00:00<00:00,  6.86it/s]
Traceback (most recent call last):
  File "/workspace/test_cuda/test.py", line 33, in <module>
    _ = pipe(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 1279, in __call__
    image = self.vae.decode(latents, return_dict=False)[0]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 435, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1170, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 499, in __call__
    return _compile(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 850, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 246, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 85, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 668, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1284, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 194, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 610, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2546, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1550, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/lazy.py", line 132, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 300, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 100, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2761, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2877, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1489, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 358, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 300, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 100, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2761, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2877, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1489, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 434, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2761, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2877, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1550, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 358, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 300, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 100, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2761, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2877, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1489, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 358, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 300, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 100, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2761, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2877, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1562, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 358, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 300, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 100, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2761, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2877, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1562, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 358, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 300, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 100, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2761, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2877, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1562, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 358, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 300, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 100, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2761, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2877, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 910, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 822, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1489, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 345, in call_function
    return self.obj.call_method(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 598, in call_method
    **get_kwargs("memo", "prefix", "remove_duplicate")
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 524, in get_kwargs
    assert_all_args_kwargs_const()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 521, in assert_all_args_kwargs_const
    unimplemented(f"non-const NNModule method {name}")
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 224, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: non-const NNModule method named_modules

from user code:
   File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 314, in decode
    decoded = self._decode(z).sample
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 285, in _decode
    dec = self.decoder(z)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/autoencoders/vae.py", line 293, in forward
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2528, in parameters
    for name, param in self.named_parameters(recurse=recurse):
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2555, in named_parameters
    gen = self._named_members(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2492, in _named_members
    self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2718, in named_modules
    yield from module.named_modules(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
tolgacangoz commented 2 months ago

I tried with madebyollin/sdxl-vae-fp16-fix's VAE and it works. This VAE isn't forced upcasting unlike the default one; thus your hypothesis might be correct! (pipe.vae.config.force_upcast = False can also be applied for the default one to test this situation.)

sayakpaul commented 2 months ago

Nice catch here. @tolgacangoz do you want to try a fix in a PR?

tolgacangoz commented 2 months ago

@Danielmic Could you share your commit e-mail address to add you as a co-author?

Danielmic commented 1 month ago

danielmccchen@gmail.com @tolgacangoz

github-actions[bot] commented 4 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

tolgacangoz commented 3 days ago

@Danielmic if you want, feel free to take this PR. I haven't found time for this.