chengzeyi / stable-fast

Best inference performance optimization framework for HuggingFace Diffusers on NVIDIA GPUs.
MIT License

torch.jit encounters xformers::efficient_attention_forward_cutlass and throws a RuntimeError #107

Closed HoiM closed 8 months ago

HoiM commented 8 months ago

I'm compiling the AnimateDiff pipeline with stable-fast. This model uses a custom pipeline based on diffusers 0.11.1.

The config is set as below. (I enabled xformers on the pipeline itself, so I did not enable it in the stable-fast config.)

import sfast.compilers.diffusion_pipeline_compiler
sf_config = sfast.compilers.diffusion_pipeline_compiler.CompilationConfig.Default()
sf_config.enable_xformers = False
sf_config.enable_triton = True
pipeline = sfast.compilers.diffusion_pipeline_compiler.compile(pipeline, sf_config)

But I got this error:

Traceback (most recent call last):
  File "/path/to/envs/a100-codelab/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/path/to/envs/a100-codelab/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/path/to/aigc/animate-diff/AnimateDiff/scripts/animate_sf.py", line 204, in <module>
    main(args)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/path/to/aigc/animate-diff/AnimateDiff/scripts/animate_sf.py", line 165, in main
    sample = pipeline(
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/path/to/aigc/animate-diff/AnimateDiff/animatediff/pipelines/pipeline_animation.py", line 424, in __call__
    down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/sfast/jit/trace_helper.py", line 51, in wrapper
    traced_m, call_helper = trace_with_kwargs(
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/sfast/jit/trace_helper.py", line 25, in trace_with_kwargs
    traced_module = better_trace(TraceablePosArgOnlyModuleWrapper(func),
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/sfast/jit/utils.py", line 32, in better_trace
    script_module = torch.jit.trace(func, *args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/jit/_trace.py", line 798, in trace
    return trace_module(
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/jit/_trace.py", line 1065, in trace_module
    module._c._create_method_from_trace(
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/sfast/jit/trace_helper.py", line 154, in forward
    outputs = self.module(*orig_args, **orig_kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/sfast/jit/trace_helper.py", line 89, in forward
    return self.func(*args, **kwargs)
  File "/path/to/aigc/animate-diff/AnimateDiff/animatediff/models/sparse_controlnet.py", line 527, in forward
    sample, res_samples = downsample_block(
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/path/to/aigc/animate-diff/AnimateDiff/animatediff/models/unet_blocks.py", line 408, in forward
    hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/path/to/aigc/animate-diff/AnimateDiff/animatediff/models/attention.py", line 117, in forward
    hidden_states = block(
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/path/to/aigc/animate-diff/AnimateDiff/animatediff/models/attention.py", line 273, in forward
    hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/path/to/aigc/animate-diff/AnimateDiff/diffusers/models/attention.py", line 633, in forward
    hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
  File "/path/to/aigc/animate-diff/AnimateDiff/diffusers/models/attention.py", line 728, in _memory_efficient_attention_xformers
    hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 223, in memory_efficient_attention
    return _memory_efficient_attention(
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 321, in _memory_efficient_attention
    return _memory_efficient_attention_forward(
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 341, in _memory_efficient_attention_forward
    out, *_ = op.apply(inp, needs_gradient=False)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/xformers/ops/fmha/cutlass.py", line 202, in apply
    return cls.apply_bmhk(inp, needs_gradient=needs_gradient)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/xformers/ops/fmha/cutlass.py", line 266, in apply_bmhk
    out, lse, rng_seed, rng_offset = cls.OPERATOR(
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/_ops.py", line 692, in __call__
    return self._op(*args, **kwargs or {})
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/sfast/jit/overrides.py", line 21, in __torch_function__
    return func(*args, **kwargs)
  File "/path/to/envs/a100-codelab/lib/python3.8/site-packages/torch/_ops.py", line 692, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_cutlass

Is this a bug in xformers?

HoiM commented 8 months ago

The problem occurs on a V100. When I switch to an A30, stable-fast works well!

Closing the issue, but a reply would still be appreciated.
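For context, the V100 is a Volta-generation GPU (compute capability 7.0) while the A30 is Ampere (8.0), and xformers' CUTLASS attention kernels take different code paths depending on the architecture, which may explain why only the V100 hits the failing path. A hypothetical helper for gating such an optimization on compute capability (the lookup table and function name are illustrative, not part of stable-fast; on a live system one would query torch.cuda.get_device_capability() instead):

```python
# Illustrative mapping from GPU name to CUDA compute capability.
# In a real program, prefer torch.cuda.get_device_capability(device).
KNOWN_CAPABILITIES = {
    "Tesla V100": (7, 0),   # Volta
    "NVIDIA A30": (8, 0),   # Ampere
    "NVIDIA A100": (8, 0),  # Ampere
}

def supports_sm80(gpu_name: str) -> bool:
    """Return True if the GPU is at least compute capability 8.0 (Ampere)."""
    major, minor = KNOWN_CAPABILITIES.get(gpu_name, (0, 0))
    return (major, minor) >= (8, 0)
```

A caller could fall back to a plain attention implementation when `supports_sm80` returns False, instead of relying on the CUTLASS kernel.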

chengzeyi commented 8 months ago

@HoiM You should enable xformers through stable-fast's own interface, since stable-fast needs to apply some patches to it.
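Based on the config fields shown earlier in this issue, a minimal sketch of that approach (assuming `pipeline` is an already-constructed diffusers pipeline that has not had xformers enabled on it directly):

```python
import sfast.compilers.diffusion_pipeline_compiler as dpc

# Let stable-fast enable and patch xformers itself, rather than calling
# pipeline.enable_xformers_memory_efficient_attention() beforehand.
sf_config = dpc.CompilationConfig.Default()
sf_config.enable_xformers = True
sf_config.enable_triton = True
pipeline = dpc.compile(pipeline, sf_config)
```

With `enable_xformers = True`, stable-fast can wrap the xformers operators so that they trace cleanly under torch.jit, which is what the unpatched `xformers::efficient_attention_forward_cutlass` call was failing to do.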