Open foreverpiano opened 1 month ago
@drisspg @janeyx99
I am not sure whether this is a torch.compile issue. When I disable torch.compile, I get an OOM error instead.
full log
root@ef9ada440938:/workspace/attention-gym# python examples/1.py
Traceback (most recent call last):
File "examples/1.py", line 46, in <module>
hidden_states = flex_attention(query, key, value, block_mask=block_mask)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 462, in _fn
return fn(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 1182, in __call__
return self._torchdynamo_orig_callable(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 991, in __call__
result = self._inner_convert(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 506, in __call__
return _compile(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 862, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
r = func(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 680, in compile_inner
out_code = transform_code_object(code, transform)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 201, in _fn
return fn(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 622, in transform
tracer.run()
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2657, in run
super().run()
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 925, in run
while self.step():
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 837, in step
self.dispatch_table[inst.opcode](self, inst)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2848, in RETURN_VALUE
self._return(inst)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2833, in _return
self.output.compile_subgraph(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1109, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1351, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
r = func(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1442, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1423, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/__init__.py", line 2216, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 1604, in compile_fx
return aot_autograd(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/backends/common.py", line 69, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1005, in aot_module_simplified
compiled_fn = dispatch_and_compile()
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 994, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
r = func(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 716, in create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 565, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
r = func(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 1510, in fw_compiler_base
return inner_compile(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/repro/after_aot.py", line 84, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/debug.py", line 307, in inner
return fn(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 440, in wrapper
return fn(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
r = func(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 654, in compile_fx_inner
compiled_graph = fx_codegen_and_compile(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 929, in fx_codegen_and_compile
compiled_fn = graph.compile_to_fn()
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1835, in compile_to_fn
return self.compile_to_module().call
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 272, in time_wrapper
r = func(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1789, in compile_to_module
mod = PyCodeCache.load_by_key_path(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/codecache.py", line 3000, in load_by_key_path
mod = _reload_python_module(key, path)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_root/at/cat26xdxgms2f3opw6njqxwfudtsx5646ukznfn6472a6d6cs3j3.py", line 393, in <module>
async_compile.wait(globals())
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/async_compile.py", line 261, in wait
scope[key] = result.result()
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/codecache.py", line 3453, in result
result = self.future.result()
File "/workspace/miniconda3/envs/opensora/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/workspace/miniconda3/envs/opensora/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
SubprocException: An exception occurred in a subprocess:
Traceback (most recent call last):
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/triton/language/core.py", line 35, in wrapper
return fn(*args, **kwargs)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/triton/language/core.py", line 1220, in full
shape = _shape_check_impl(shape)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/triton/language/core.py", line 1205, in _shape_check_impl
raise ValueError(f"Shape element {i} must be a power of 2")
ValueError: Shape element 1 must be a power of 2
The above exception was the direct cause of the following exception:
triton.compiler.errors.CompilationError: at 10:11:
def zeros(shape, dtype):
"""
Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., :code:`tl.float16`
:type dtype: DType
"""
return core.full(shape, 0, dtype)
^
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 270, in do_job
result = job()
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/runtime/compile_tasks.py", line 68, in _worker_compile_triton
load_kernel().precompile(warm_cache_only=True)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 234, in precompile
compiled_binary, launcher = self._precompile_config(
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 418, in _precompile_config
triton.compile(*compile_args, **compile_kwargs),
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/triton/compiler/compiler.py", line 276, in compile
module = src.make_ir(options, codegen_fns, context)
File "/workspace/miniconda3/envs/opensora/lib/python3.8/site-packages/triton/compiler/compiler.py", line 113, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
triton.compiler.errors.CompilationError: at 88:10:
sparse_idx_h = off_h % SPARSE_H
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
SPARSE_Q_BLOCK_CNT: tl.constexpr = tl.cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE)
SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
^
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
cc @Chillee
I am able to repro
The easiest way to unblock yourself in the meantime would be to pad the head_dim to 128 with zeros and then slice the output. We will likely update FlexAttention to handle this internally, so that users do not need to pad manually.
> The easiest thing to unblock on your end in the meantime would be to pad the head_dim to 128 with 0s and then slice the output. We will likely update this to be opaque to the users of flex
It seems this will change the denominator sqrt(d_k) (96 -> 128), thus changing the attention values. One more thing: the time cost is doubled because of the unnecessary padding.
> It seems it will affect the denominator sqrt(d_k) (96->128), thus changing the value of attn
You can pass your own scale to FlexAttention (e.g. `scale=1 / math.sqrt(96)`), so the padding does not change the softmax scaling. I agree the performance is worse, but the slowdown should be only about 33% at most.
I have a fix: https://github.com/pytorch/pytorch/pull/133495, but it is currently hitting a Triton segfault in the backward pass.
Thanks for the update. I will wait for the fix that produces the correct result.
🐛 Describe the bug
When I change the head dimension from 96 to 64, it works correctly. So how can I use a head dimension of 96?
Versions
nightly version
cc @ezyang @chauhang @penguinwu