Open clessig opened 3 weeks ago
This error is horrible, but the underlying issue here is that `tc_tokens_cell_idx = torch.cat( [i * torch.ones( l, dtype=torch.int64) ...])`
is creating tensors on the CPU.
If I add a `device='cuda'`, then it runs.
Yes, the device='cuda'
fixes it. Thanks!
But a more informative error message would be useful :)
I now tried to run it with backward which leads to another error (in my real code I get OOM, which shouldn't happen if things are properly lowered, but let's do step by step :)):
import code
import time
import warnings
import numpy as np
import torch
from torch.nn.attention.flex_attention import flex_attention, create_mask, create_block_mask
import astropy_healpix as hp
# HEALPix refinement level; level 4 gives 12 * 4**4 = 3072 cells.
hlc = 4
num_healpix_cells = 12 * 4**hlc
print( f'seq_length : {num_healpix_cells}')

# astropy-healpix warns for cells with missing neighbours; silence for the repro.
with warnings.catch_warnings(action="ignore"):
    # (num_cells, 8) neighbour indices; -1 marks a missing neighbour.
    nbours = hp.neighbours( np.arange(num_healpix_cells), 2**hlc, order='nested').transpose()

# Build the boolean cell-adjacency matrix with one vectorized scatter instead
# of a per-element double loop (each element assignment on a CUDA tensor is a
# separate kernel launch). Entries with index -1 are skipped; the matrix is
# already initialised to False everywhere else.
nbours_mat = torch.zeros( (num_healpix_cells, num_healpix_cells), dtype=torch.bool, device='cuda')
nb = torch.as_tensor(nbours, dtype=torch.int64)
rows = torch.arange(num_healpix_cells).unsqueeze(1).expand_as(nb)
valid = nb >= 0
nbours_mat[rows[valid].to('cuda'), nb[valid].to('cuda')] = True
hp_adjacency = nbours_mat

# Real data would be loaded here; the repro only needs the shape/dtype.
# tc_tokens = torch.from_numpy( np.load( 'tc_tokens.npy')).to(torch.float16).to('cuda')
tc_tokens = torch.ones( [204458, 256], dtype=torch.float16, device='cuda', requires_grad=True)
# Per-cell token counts; sums to tc_tokens.shape[0] — TODO confirm against data.
tcs_lens = torch.from_numpy( np.load( './tcs_lens.npy')).to(torch.int32).to('cuda')
print( f'tc_tokens = {tc_tokens.shape}')
print( f'tcs_lens = {tcs_lens.shape}')

# Map every token to the index of its HEALPix cell. repeat_interleave replaces
# the torch.cat-of-ones construction; tensors MUST live on 'cuda' here —
# CPU tensors inside the compiled score_mod break torch.compile (see issue).
tc_tokens_cell_idx = torch.repeat_interleave(
    torch.arange(len(tcs_lens), dtype=torch.int64, device='cuda'),
    tcs_lens.to(torch.int64))
def sparsity_mask( score, b, h, q_idx, kv_idx):
    """Attention sparsity callback: True where the HEALPix cells of the
    query token and the key/value token are adjacent.

    NOTE(review): this returns a *boolean* (mask_mod semantics), yet the
    repro passes it as score_mod= to flex_attention — presumably intentional
    to reproduce the compile error; a production version would use
    create_block_mask with a mask_mod instead. TODO confirm.
    Reads module-level globals hp_adjacency and tc_tokens_cell_idx.
    """
    return hp_adjacency[ tc_tokens_cell_idx[q_idx], tc_tokens_cell_idx[kv_idx] ]
# Compile flex_attention once; dynamic shapes disabled for the repro.
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)

# Take the first 64 channels and add batch/head dims: [1, 1, seq, head_dim].
toks = tc_tokens[None, None, :, :64]

# Self-attention: the same tensor serves as query, key and value.
out = compiled_flex_attention(toks, toks, toks, score_mod=sparsity_mask)

# Dummy MSE loss against zeros, purely so backward() can be exercised.
t = torch.zeros_like(out)
mse = torch.nn.MSELoss()
loss = mse(t, out)
loss.backward()
The relevant part of the error message is:
File "<template>", line 446, in top-level template code
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 391, in modification
assert isinstance(
torch._inductor.exc.LoweringException: AssertionError: Expected the subgraph to be a ComputedBuffer, got <class 'NoneType'>
And that's the full error:
Traceback (most recent call last):
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/flex_attention_repro.py", line 44, in <module>
loss.backward()
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_tensor.py", line 581, in backward
torch.autograd.backward(
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/autograd/__init__.py", line 347, in backward
_engine_run_backward(
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/autograd/function.py", line 307, in apply
return user_fn(self, *args)
^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2048, in backward
out = call_compiled_backward()
^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1954, in call_compiled_backward
CompiledFunction.compiled_bw = aot_config.bw_compiler(
^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 51, in _wrapped_bw_compiler
return disable(disable(bw_compiler)(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1466, in bw_compiler
return inner_compile(
^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
compiled_graph = FxGraphCache.load(
^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1370, in load
compiled_graph = compile_fx_fn(
^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile
compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 859, in fx_codegen_and_compile
graph.run(*example_inputs)
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 780, in run
return super().run(*args)
^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1319, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1024, in call_function
raise LoweringException(e, target, args, kwargs).with_traceback(
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1021, in call_function
out = lowerings[target](*args, **kwargs) # type: ignore[index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 361, in wrapped
out = decomp_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py", line 1771, in flex_attention_backward
flex_attention_backward_template.maybe_append_choice(
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/codegen/common.py", line 2158, in maybe_append_choice
choices.append(self.generate(**kwargs))
^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 676, in generate
template = kernel.render(self.template, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 484, in render
template.render(**self.template_env(), **kwargs),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/jinja2/environment.py", line 1304, in render
self.environment.handle_exception()
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/jinja2/environment.py", line 939, in handle_exception
raise rewrite_traceback_stack(source=source)
File "<template>", line 446, in top-level template code
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 391, in modification
assert isinstance(
torch._inductor.exc.LoweringException: AssertionError: Expected the subgraph to be a ComputedBuffer, got <class 'NoneType'>
target: flex_attention_backward
args[0]: TensorBox(StorageBox(
InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1]))
))
args[2]: TensorBox(StorageBox(
InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1]))
))
args[3]: TensorBox(StorageBox(
InputBuffer(name='getitem_2', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[13085312, 13085312, 64, 1]))
))
args[4]: TensorBox(StorageBox(
InputBuffer(name='getitem_3', layout=FixedLayout('cuda', torch.float32, size=[1, 1, 204458], stride=[204458, 204458, 1]))
))
args[5]: TensorBox(StorageBox(
InputBuffer(name='tangents_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[13085312, 13085312, 64, 1]))
))
args[6]: TensorBox(StorageBox(
Pointwise(
'cuda',
torch.float32,
def inner_fn(index):
_, _, i2 = index
tmp0 = ops.constant(0, torch.float32)
return tmp0
,
ranges=[1, 1, 204458],
origin_node=full_default_4,
origins=OrderedSet([full_default_4])
)
))
args[7]: Subgraph(name='fw_graph', graph_module=<lambda>(), graph=None)
args[8]: Subgraph(name='joint_graph', graph_module=<lambda>(), graph=None)
args[9]: (TensorBox(StorageBox(
InputBuffer(name='full', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='full_default', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]))
)), None, None, TensorBox(StorageBox(
InputBuffer(name='convert_element_type', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='convert_element_type_1', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]))
)), None, None, 1073741824, 1073741824, Subgraph(name='mask_graph', graph_module=<lambda>(), graph=None))
args[10]: 0.125
args[11]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}
args[12]: (TensorBox(StorageBox(
InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.int64, size=[204458], stride=[1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.bool, size=[3072, 3072], stride=[3072, 1]))
)))
🐛 Describe the bug
The following code generates the compile error below:
Required input can be found here: https://cloud.ovgu.de/s/355z3P6ySK4WorB (tcs_lens.npy)
Error logs
Traceback (most recent call last): File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler compiled_fn = compiler_fn(gm, self.example_inputs()) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in call compiled_gm = compiler_fn(gm, example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/init.py", line 2234, in call return compilefx(model, inputs_, config_patches=self.config) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx return aot_autograd( ^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 72, in call cg = aot_module_simplified(gm, example_inputs, self.kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified compiled_fn = dispatch_and_compile() ^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile compiledfn, = create_aot_dispatcher_function( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function return _create_aot_dispatcher_function( 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function compiled_fn, fw_metadata = compiler_fn( ^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base compiled_fw = compiler(fw_module, updated_flat_args) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base return _fw_compiler_base(model, example_inputs, is_inference) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base return inner_compile( ^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper inner_compiled_fn = compiler_fn(gm, example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner compiled_graph = FxGraphCache.load( ^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1370, in load 
compiled_graph = compile_fx_fn( ^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile compiled_graph = fx_codegen_and_compile(gm, example_inputs, fx_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 859, in fx_codegen_and_compile graph.run(example_inputs) File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 780, in run return super().run(args) ^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run self.env[node] = self.run_node(node) ^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1319, in run_node result = super().run_node(n) ^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node return getattr(self, n.op)(n.target, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1024, in call_function raise LoweringException(e, target, args, kwargs).with_traceback( File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1021, in call_function out = lowerings[target](*args, *kwargs) # type: ignore[index] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File 
"/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 361, in wrapped out = decomp_fn(args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py", line 849, in flex_attention flex_attention_template.maybe_append_choice( File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/codegen/common.py", line 2158, in maybe_append_choice choices.append(self.generate(kwargs)) ^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 676, in generate template = kernel.render(self.template, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 484, in render template.render(self.template_env(), kwargs), ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/jinja2/environment.py", line 1304, in render self.environment.handle_exception() File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/jinja2/environment.py", line 939, in handle_exception raise rewrite_traceback_stack(source=source) File "", line 324, in top-level template code File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 397, in modification out = subgraph.data.inner_fn(()) ^^^^^^^^^^^^^^^^^^^^^^ torch._inductor.exc.LoweringException: AttributeError: 'MultiOutput' object has no attribute 'inner_fn' target: flex_attention args[0]: 
TensorBox(StorageBox( InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1])) )) args[1]: TensorBox(StorageBox( InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1])) )) args[2]: TensorBox(StorageBox( InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1])) )) args[3]: Subgraph(name='sdpa_score0', graph_module=(), graph=None)
args[4]: (TensorBox(StorageBox(
ComputedBuffer(name='buf8', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def innerfn(index):
, , = index
tmp0 = ops.constant(1, torch.int32)
return tmp0
,
ranges=[1, 1, 1],
origin_node=full,
origins=OrderedSet([full])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf9', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def innerfn(index):
, , , _ = index
tmp0 = ops.constant(0, torch.int32)
return tmp0
,
ranges=[1, 1, 1, 1],
origin_node=full_default,
origins=OrderedSet([full_default])
))
)), None, None, TensorBox(StorageBox(
ComputedBuffer(name='buf10', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def innerfn(index):
, , = index
tmp0 = ops.load(buf0, 0)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int32)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1],
origin_node=convert_element_type,
origins=OrderedSet([sum_1, convert_element_type])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf11', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def innerfn(index):
, , , _ = index
tmp0 = ops.index_expr(0, dtype=torch.int16)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1, 1],
origin_node=convert_element_type_1,
origins=OrderedSet([sort, convert_element_type_1])
))
)), None, None, 1073741824, 1073741824, Subgraph(name='sdpa_mask0', graph_module=(), graph=None))
args[5]: 0.125
args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': False}
args[7]: (TensorBox(StorageBox(
InputBuffer(name='arg2_1', layout=FixedLayout('cpu', torch.int64, size=[204458], stride=[1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.bool, size=[3072, 3072], stride=[3072, 1]))
)))
args[8]: ()
The above exception was the direct cause of the following exception:
Traceback (most recent call last): File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/flex_attention_repro.py", line 46, in
out = compiled_flex_attention( toks, toks, toks, score_mod=sparsity_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, kwargs)
^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1269, in call
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1064, in call
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in call
return _compile(
^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, *kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
return fn(args, kwargs)
^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
tracer.run()
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
super().run()
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
while self.step():
^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
self.dispatch_table[inst.opcode](self, inst)
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
self._return(inst)
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
self.output.compile_subgraph(
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'MultiOutput' object has no attribute 'inner_fn'
target: flex_attention
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1]))
))
args[2]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1]))
))
args[3]: Subgraph(name='sdpa_score0', graph_module=(), graph=None)
args[4]: (TensorBox(StorageBox(
ComputedBuffer(name='buf8', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def innerfn(index):
, , = index
tmp0 = ops.constant(1, torch.int32)
return tmp0
,
ranges=[1, 1, 1],
origin_node=full,
origins=OrderedSet([full])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf9', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def innerfn(index):
, , , _ = index
tmp0 = ops.constant(0, torch.int32)
return tmp0
,
ranges=[1, 1, 1, 1],
origin_node=full_default,
origins=OrderedSet([full_default])
))
)), None, None, TensorBox(StorageBox(
ComputedBuffer(name='buf10', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def innerfn(index):
, , = index
tmp0 = ops.load(buf0, 0)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int32)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1],
origin_node=convert_element_type,
origins=OrderedSet([sum_1, convert_element_type])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf11', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def innerfn(index):
, , , _ = index
tmp0 = ops.index_expr(0, dtype=torch.int16)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1, 1],
origin_node=convert_element_type_1,
origins=OrderedSet([sort, convert_element_type_1])
))
)), None, None, 1073741824, 1073741824, Subgraph(name='sdpa_mask0', graph_module=(), graph=None))
args[5]: 0.125
args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': False}
args[7]: (TensorBox(StorageBox(
InputBuffer(name='arg2_1', layout=FixedLayout('cpu', torch.int64, size=[204458], stride=[1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.bool, size=[3072, 3072], stride=[3072, 1]))
)))
args[8]: ()
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting: import torch._dynamo torch._dynamo.config.suppress_errors = True
Minified repro
W1102 11:23:50.178000 900974 pyenv312/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py:110] [0/0] Compiled Fx GraphModule failed. Creating script to minify the error. W1102 11:23:50.367000 900974 pyenv312/lib/python3.12/site-packages/torch/_dynamo/debug_utils.py:279] [0/0] Writing minified repro to: W1102 11:23:50.367000 900974 pyenv312/lib/python3.12/site-packages/torch/_dynamo/debug_utils.py:279] [0/0] /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/torch_compile_debug/run_2024_11_02_11_23_50_182897-pid_900974/minifier/minifier_launcher.py Traceback (most recent call last): File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler compiled_fn = compiler_fn(gm, self.example_inputs()) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 107, in call compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/init.py", line 2234, in call return compilefx(model, inputs_, config_patches=self.config) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx return aot_autograd( ^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 72, in call cg = aot_module_simplified(gm, example_inputs, self.kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File 
"/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified compiled_fn = dispatch_and_compile() ^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile compiledfn, = create_aot_dispatcher_function( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function return _create_aot_dispatcher_function( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function compiled_fn, fw_metadata = compiler_fn( ^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base compiled_fw = compiler(fw_module, updated_flat_args) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base return _fw_compiler_base(model, example_inputs, is_inference) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base return inner_compile( ^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner return wrap_compiler_debug(_compile_fx_inner, 
compiler_name="inductor")( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper inner_compiled_fn = compiler_fn(gm, example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner compiled_graph = FxGraphCache.load( ^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1370, in load compiled_graph = compile_fx_fn( ^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile compiled_graph = fx_codegen_and_compile(gm, example_inputs, fx_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 859, in fx_codegen_and_compile graph.run(example_inputs) File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 780, in run return super().run(args) ^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run self.env[node] = self.run_node(node) ^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1319, in run_node result = super().run_node(n) ^^^^^^^^^^^^^^^^^^^ File 
"/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node return getattr(self, n.op)(n.target, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1024, in call_function raise LoweringException(e, target, args, kwargs).with_traceback( File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1021, in call_function out = lowerings[target](*args, *kwargs) # type: ignore[index] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 361, in wrapped out = decomp_fn(args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py", line 849, in flex_attention flex_attention_template.maybe_append_choice( File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/codegen/common.py", line 2158, in maybe_append_choice choices.append(self.generate(kwargs)) ^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 676, in generate template = kernel.render(self.template, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 484, in render template.render(self.template_env(), kwargs), ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File 
"/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/jinja2/environment.py", line 1304, in render self.environment.handle_exception() File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/jinja2/environment.py", line 939, in handle_exception raise rewrite_traceback_stack(source=source) File "", line 324, in top-level template code File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 397, in modification out = subgraph.data.inner_fn(()) ^^^^^^^^^^^^^^^^^^^^^^ torch._inductor.exc.LoweringException: AttributeError: 'MultiOutput' object has no attribute 'inner_fn' target: flex_attention args[0]: TensorBox(StorageBox( InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1])) )) args[1]: TensorBox(StorageBox( InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1])) )) args[2]: TensorBox(StorageBox( InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1])) )) args[3]: Subgraph(name='sdpa_score0', graph_module=(), graph=None)
args[4]: (TensorBox(StorageBox(
ComputedBuffer(name='buf8', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
    _, _, _ = index
tmp0 = ops.constant(1, torch.int32)
return tmp0
,
ranges=[1, 1, 1],
origin_node=full,
origins=OrderedSet([full])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf9', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
    _, _, _, _ = index
tmp0 = ops.constant(0, torch.int32)
return tmp0
,
ranges=[1, 1, 1, 1],
origin_node=full_default,
origins=OrderedSet([full_default])
))
)), None, None, TensorBox(StorageBox(
ComputedBuffer(name='buf10', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
    _, _, _ = index
tmp0 = ops.load(buf0, 0)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int32)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1],
origin_node=convert_element_type,
origins=OrderedSet([sum_1, convert_element_type])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf11', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
    _, _, _, _ = index
tmp0 = ops.index_expr(0, dtype=torch.int16)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1, 1],
origin_node=convert_element_type_1,
origins=OrderedSet([sort, convert_element_type_1])
))
)), None, None, 1073741824, 1073741824, Subgraph(name='sdpa_mask0', graph_module=(), graph=None))
args[5]: 0.125
args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': False}
args[7]: (TensorBox(StorageBox(
InputBuffer(name='arg2_1', layout=FixedLayout('cpu', torch.int64, size=[204458], stride=[1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.bool, size=[3072, 3072], stride=[3072, 1]))
)))
args[8]: ()
The above exception was the direct cause of the following exception:
Traceback (most recent call last): File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/flex_attention_repro.py", line 46, in
out = compiled_flex_attention( toks, toks, toks, score_mod=sparsity_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1269, in call
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1064, in call
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in call
return _compile(
^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
tracer.run()
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
super().run()
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
while self.step():
^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
self.dispatch_table[inst.opcode](self, inst)
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
self._return(inst)
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
self.output.compile_subgraph(
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/pyenv312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'MultiOutput' object has no attribute 'inner_fn'
target: flex_attention
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1]))
))
args[2]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 204458, 64], stride=[52341248, 52341248, 256, 1]))
))
args[3]: Subgraph(name='sdpa_score0', graph_module=(), graph=None)
args[4]: (TensorBox(StorageBox(
ComputedBuffer(name='buf8', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
    _, _, _ = index
tmp0 = ops.constant(1, torch.int32)
return tmp0
,
ranges=[1, 1, 1],
origin_node=full,
origins=OrderedSet([full])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf9', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
    _, _, _, _ = index
tmp0 = ops.constant(0, torch.int32)
return tmp0
,
ranges=[1, 1, 1, 1],
origin_node=full_default,
origins=OrderedSet([full_default])
))
)), None, None, TensorBox(StorageBox(
ComputedBuffer(name='buf10', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
    _, _, _ = index
tmp0 = ops.load(buf0, 0)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int32)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1],
origin_node=convert_element_type,
origins=OrderedSet([sum_1, convert_element_type])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf11', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
    _, _, _, _ = index
tmp0 = ops.index_expr(0, dtype=torch.int16)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1, 1],
origin_node=convert_element_type_1,
origins=OrderedSet([sort, convert_element_type_1])
))
)), None, None, 1073741824, 1073741824, Subgraph(name='sdpa_mask0', graph_module=(), graph=None))
args[5]: 0.125
args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': False}
args[7]: (TensorBox(StorageBox(
InputBuffer(name='arg2_1', layout=FixedLayout('cpu', torch.int64, size=[204458], stride=[1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.bool, size=[3072, 3072], stride=[3072, 1]))
)))
args[8]: ()
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Minifier script written to /gpfs/home/ecm/ecm327663/obs6/ai-obs-experimental-transformer/torch_compile_debug/run_2024_11_02_11_23_50_182897-pid_900974/minifier/minifier_launcher.py. Run this script to find the smallest traced graph which reproduces this error.
You can suppress this exception and fall back to eager by setting: import torch._dynamo torch._dynamo.config.suppress_errors = True
Versions
Collecting environment information... PyTorch version: 2.5.1+cu124 Is debug build: False CUDA used to build PyTorch: 12.4 ROCM used to build PyTorch: N/A
OS: Red Hat Enterprise Linux 9.2 (Plow) (x86_64) GCC version: (GCC) 11.4.0 Clang version: Could not collect CMake version: Could not collect Libc version: glibc-2.34
Python version: 3.12.1 (main, Apr 29 2024, 16:28:15) [GCC Intel(R) C++ gcc 11.3.1 mode] (64-bit runtime) Python platform: Linux-5.14.0-284.30.1.el9_2.x86_64-x86_64-with-glibc2.34 Is CUDA available: True CUDA runtime version: 12.4.131 CUDA_MODULE_LOADING set to: LAZY GPU models and configuration: GPU 0: NVIDIA H100 GPU 1: NVIDIA H100 GPU 2: NVIDIA H100 GPU 3: NVIDIA H100
Nvidia driver version: 535.86.10 cuDNN version: Could not collect HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True
CPU: Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Address sizes: 46 bits physical, 57 bits virtual Byte Order: Little Endian CPU(s): 160 On-line CPU(s) list: 0-159 Vendor ID: GenuineIntel Model name: Intel(R) Xeon(R) Platinum 8460Y+ CPU family: 6 Model: 143 Thread(s) per core: 2 Core(s) per socket: 40 Socket(s): 2 Stepping: 8 CPU max MHz: 3700.0000 CPU min MHz: 800.0000 BogoMIPS: 4000.00 Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req hfi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities Virtualization: VT-x L1d cache: 3.8 MiB (80 instances) L1i cache: 2.5 MiB (80 instances) L2 cache: 160 MiB (80 instances) L3 cache: 210 MiB (2 instances) NUMA node(s): 4 NUMA node0 CPU(s): 0-19,80-99 NUMA node1 CPU(s): 20-39,100-119 NUMA node2 CPU(s): 40-59,120-139 NUMA node3 CPU(s): 60-79,140-159 Vulnerability Itlb multihit: 
Not affected Vulnerability L1tf: Not affected Vulnerability Mds: Not affected Vulnerability Meltdown: Not affected Vulnerability Mmio stale data: Not affected Vulnerability Retbleed: Not affected Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence Vulnerability Srbds: Not affected Vulnerability Tsx async abort: Not affected
Versions of relevant libraries: [pip3] flake8==7.1.0 [pip3] mypy-extensions==1.0.0 [pip3] numpy==1.26.4 [pip3] nvidia-cublas-cu12==12.4.5.8 [pip3] nvidia-cuda-cupti-cu12==12.4.127 [pip3] nvidia-cuda-nvrtc-cu12==12.4.127 [pip3] nvidia-cuda-runtime-cu12==12.4.127 [pip3] nvidia-cudnn-cu12==9.1.0.70 [pip3] nvidia-cufft-cu12==11.2.1.3 [pip3] nvidia-curand-cu12==10.3.5.147 [pip3] nvidia-cusolver-cu12==11.6.1.9 [pip3] nvidia-cusparse-cu12==12.3.1.170 [pip3] nvidia-nccl-cu12==2.21.5 [pip3] nvidia-nvjitlink-cu12==12.4.127 [pip3] nvidia-nvtx-cu12==12.4.127 [pip3] torch==2.5.1 [pip3] triton==3.1.0 [conda] Could not collect
cc @ezyang @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @yf225 @Chillee @drisspg @yanboliang @BoyuanFeng