pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Flexattention: ValueError: Shape element 1 must be a power of 2 #133321

Open foreverpiano opened 1 month ago

foreverpiano commented 1 month ago

🐛 Describe the bug

from functools import lru_cache
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

import torch

torch._dynamo.config.cache_size_limit = 1000

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention, dynamic=False)

B = 2
H = 16
S = 28800
D = 96

query = torch.randn(
    B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
)
key = torch.randn(
    B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
)
value = torch.randn(
    B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
)

d_k, seq_len = query.size(-1), query.size(-2)
block_size = seq_len // 8
frame_size = seq_len // 2

@lru_cache
def create_block_mask_cached(mask_mod, B, H, M, N, device="cuda"):
    # Cache the BlockMask so repeated calls with the same mask_mod and
    # shapes don't rebuild it (create_block_mask takes a mask_mod).
    block_mask = create_block_mask(mask_mod, B, H, M, N, device=device)
    return block_mask

def prefix_lm_causal_mask(b, h, q_idx, kv_idx):
    # Keep a (q, kv) pair if the query or key falls in the first
    # block_size positions of its frame, or both lie in the same block.
    row_mask = q_idx % frame_size < block_size
    col_mask = kv_idx % frame_size < block_size
    diagonal_mask = (q_idx // block_size) == (kv_idx // block_size)
    return row_mask | col_mask | diagonal_mask

def noop(b, h, q_idx, kv_idx):
    return True

block_mask = create_block_mask_cached(prefix_lm_causal_mask, 1, 1, seq_len, seq_len)
hidden_states = flex_attention(query, key, value, block_mask=block_mask)

When changing D from 96 to 64, it works correctly. So how can I use head_dim=96?
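For context, a quick sketch of the constraint being hit (inferred from the Triton traceback in the full log below, so treat this as an assumption): block shapes inside the generated kernel must be powers of two, and head_dim D appears to be used directly as BLOCK_DMODEL.

def is_power_of_two(n: int) -> bool:
    # True iff n has exactly one bit set
    return n > 0 and (n & (n - 1)) == 0

print(is_power_of_two(64))  # True  -> compiles
print(is_power_of_two(96))  # False -> "Shape element 1 must be a power of 2"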

Versions

nightly version

cc @ezyang @chauhang @penguinwu

foreverpiano commented 1 month ago

@drisspg @janeyx99

foreverpiano commented 1 month ago

I am not sure whether this is a torch.compile issue. When I disable torch.compile, I get an OOM error instead.

foreverpiano commented 1 month ago

Full log:

root@ef9ada440938:/workspace/attention-gym# python examples/1.py 
Traceback (most recent call last):
  File "examples/1.py", line 46, in <module>
    hidden_states = flex_attention(query, key, value, block_mask=block_mask)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 462, in _fn
    return fn(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 1182, in __call__
    return self._torchdynamo_orig_callable(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 991, in __call__
    result = self._inner_convert(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 506, in __call__
    return _compile(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 862, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 680, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 201, in _fn
    return fn(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 622, in transform
    tracer.run()
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2657, in run
    super().run()
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 925, in run
    while self.step():
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 837, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2848, in RETURN_VALUE
    self._return(inst)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2833, in _return
    self.output.compile_subgraph(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1109, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1351, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1442, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1423, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/__init__.py", line 2216, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 1604, in compile_fx
    return aot_autograd(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/backends/common.py", line 69, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1005, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 994, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 716, in create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 565, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 1510, in fw_compiler_base
    return inner_compile(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/repro/after_aot.py", line 84, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/debug.py", line 307, in inner
    return fn(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 440, in wrapper
    return fn(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 654, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 929, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1835, in compile_to_fn
    return self.compile_to_module().call
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1789, in compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/codecache.py", line 3000, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_root/at/cat26xdxgms2f3opw6njqxwfudtsx5646ukznfn6472a6d6cs3j3.py", line 393, in <module>
    async_compile.wait(globals())
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/async_compile.py", line 261, in wait
    scope[key] = result.result()
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/codecache.py", line 3453, in result
    result = self.future.result()
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
SubprocException: An exception occurred in a subprocess:

Traceback (most recent call last):
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/triton/language/core.py", line 35, in wrapper
    return fn(*args, **kwargs)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/triton/language/core.py", line 1220, in full
    shape = _shape_check_impl(shape)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/triton/language/core.py", line 1205, in _shape_check_impl
    raise ValueError(f"Shape element {i} must be a power of 2")
ValueError: Shape element 1 must be a power of 2

The above exception was the direct cause of the following exception:

triton.compiler.errors.CompilationError: at 10:11:
def zeros(shape, dtype):
    """
    Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    """
    return core.full(shape, 0, dtype)
           ^

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 270, in do_job
    result = job()
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/runtime/compile_tasks.py", line 68, in _worker_compile_triton
    load_kernel().precompile(warm_cache_only=True)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 234, in precompile
    compiled_binary, launcher = self._precompile_config(
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 418, in _precompile_config
    triton.compile(*compile_args, **compile_kwargs),
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/triton/compiler/compiler.py", line 276, in compile
    module = src.make_ir(options, codegen_fns, context)
  File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/triton/compiler/compiler.py", line 113, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
triton.compiler.errors.CompilationError: at 88:10:
    sparse_idx_h = off_h % SPARSE_H

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)

    SPARSE_Q_BLOCK_CNT: tl.constexpr = tl.cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE)
    SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
          ^

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

janeyx99 commented 1 month ago

cc @Chillee

drisspg commented 1 month ago

I am able to repro

drisspg commented 1 month ago

The easiest way to unblock on your end in the meantime would be to pad head_dim to 128 with zeros and then slice the output. We will likely update this so it is transparent to users of flex.
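A minimal sketch of that workaround, reusing the tensors from the repro above (the helper name is illustrative, not a real API):

import torch.nn.functional as F

# Zero-pad head_dim up to the next power of two (96 -> 128), run
# flex_attention, then slice the padded channels off the output.
# Zero-padding q/k leaves the q @ k^T scores unchanged; zero-padding v
# only adds all-zero output channels.
def flex_attention_padded(q, k, v, block_mask=None):
    d = q.size(-1)                      # 96 in the repro above
    d_pad = 1 << (d - 1).bit_length()   # next power of two: 128
    q, k, v = (F.pad(t, (0, d_pad - d)) for t in (q, k, v))
    out = flex_attention(q, k, v, block_mask=block_mask)
    return out[..., :d]                 # drop the padded channels
# Caveat: the default softmax scale is now based on the padded dim;
# see the discussion below.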

foreverpiano commented 1 month ago

> The easiest way to unblock on your end in the meantime would be to pad head_dim to 128 with zeros and then slice the output. We will likely update this so it is transparent to users of flex.


It seems padding will affect the denominator sqrt(d_k) (96 -> 128), thus changing the attention values. Also, the runtime doubles because of the unnecessary padding.

Chillee commented 1 month ago

> It seems padding will affect the denominator sqrt(d_k) (96 -> 128), thus changing the attention values.

You can pass in your own scale to FlexAttention. I'd agree the performance is worse, but it should only be about 33% worse at most.
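Concretely, a sketch of that fix, assuming the padded q_pad/k_pad/v_pad from the illustrative helper above: compute the scale from the original head_dim so the softmax still divides by sqrt(96) even though the tensors are padded to 128.

import math

# Pass the scale for the *original* head_dim (96), not the padded one
# (128), so the attention values match the unpadded computation.
out = flex_attention(q_pad, k_pad, v_pad, block_mask=block_mask,
                     scale=1.0 / math.sqrt(96))
hidden_states = out[..., :96]  # slice away the padded channels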

drisspg commented 1 month ago

I have a fix: https://github.com/pytorch/pytorch/pull/133495; currently hitting a Triton segfault in the backward pass.

foreverpiano commented 1 month ago

Thanks for your update. I'll wait for the fix to land so the results are correct.