Open tobiasvanderwerff opened 11 hours ago
If using `torch.compile(..., mode='max-autotune', ...)`, I get a different error (also resolved by the fix above):
Traceback (most recent call last):
File "/home/azureuser/a.py", line 36, in <module>
m(q, k, v)
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1292, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1087, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 530, in __call__
return _compile(
^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 933, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 675, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 708, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 220, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 643, in transform
tracer.run()
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2776, in run
super().run()
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 979, in run
while self.step():
^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 891, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2967, in RETURN_VALUE
self._return(inst)
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2952, in _return
self.output.compile_subgraph(
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/__init__.py", line 2235, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1272, in compile_fx
return compile_fx(
^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1533, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1359, in fw_compiler_base
return _fw_compiler_base(model, example_inputs, is_inference)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1430, in _fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 479, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 665, in _compile_fx_inner
compiled_graph = FxGraphCache.load(
^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1425, in load
compiled_graph = compile_fx_fn(
^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 574, in codegen_and_compile
compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 863, in fx_codegen_and_compile
graph.run(*example_inputs)
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 780, in run
return super().run(*args)
^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1357, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1023, in call_function
raise LoweringException(e, target, args, kwargs).with_traceback(
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1020, in call_function
out = lowerings[target](*args, **kwargs) # type: ignore[index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 363, in wrapped
out = decomp_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/kernel/flex_attention.py", line 913, in flex_attention
autotune_select_algorithm(
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1729, in autotune_select_algorithm
return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1224, in __call__
inputs_key = create_inputs_key(input_nodes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1138, in create_inputs_key
return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1138, in <listcomp>
return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1698, in key_of
node.get_stride(),
^^^^^^^^^^^^^^^
File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/ir.py", line 6276, in __getattr__
fn = getattr(self.data, name)
^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'View' object has no attribute 'get_stride'
target: flex_attention
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.bfloat16, size=[100, 12, 128, 64], stride=[98304, 8192, 64, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg2_1', layout=FixedLayout('cuda', torch.bfloat16, size=[100, 12, 128, 64], stride=[98304, 8192, 64, 1]))
))
args[2]: TensorBox(StorageBox(
InputBuffer(name='arg3_1', layout=FixedLayout('cuda', torch.bfloat16, size=[100, 12, 128, 64], stride=[98304, 8192, 64, 1]))
))
args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
args[4]: (TensorBox(StorageBox(
ComputedBuffer(name='buf2', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
_, _, _ = index
tmp0 = ops.constant(1, torch.int32)
return tmp0
,
ranges=[1, 1, 1],
origin_node=full,
origins=OrderedSet([full])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf3', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
_, _, _, _ = index
tmp0 = ops.constant(0, torch.int32)
return tmp0
,
ranges=[1, 1, 1, 1],
origin_node=full_default,
origins=OrderedSet([full_default])
))
)), None, None, TensorBox(StorageBox(
ComputedBuffer(name='buf4', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
_, _, _ = index
tmp0 = ops.load(buf0, 0)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int32)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1],
origin_node=convert_element_type,
origins=OrderedSet([sum_1, convert_element_type])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf5', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
_, _, _, _ = index
tmp0 = ops.index_expr(0, dtype=torch.int16)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1, 1],
origin_node=convert_element_type_1,
origins=OrderedSet([convert_element_type_1, sort])
))
)), None, None, 1073741824, 1073741824, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
args[5]: 0.125
args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': False}
args[7]: (TensorBox(
View(
StorageBox(
Pointwise(
'cuda',
torch.bfloat16,
def inner_fn(index):
i0, i1, i2, i3 = index
tmp0 = ops.load(arg0_1, i3 + 12 * i2 + 1536 * i1 + 196608 * i0)
tmp1 = ops.constant(2, torch.bfloat16)
tmp2 = tmp0 * tmp1
return tmp2
,
ranges=[100, 128, 128, 12],
origin_node=mul,
origins=OrderedSet([mul])
)
),
size=[100, 12, 128, 128],
reindex=lambda i0, i1, i2, i3: [ModularIndexing(196608*i0 + 16384*i1 + 128*i2 + i3, 196608, 100), ModularIndexing(16384*i1 + 128*i2 + i3, 1536, 128), ModularIndexing(16384*i1 + 128*i2 + i3, 12, 128), ModularIndexing(16384*i1 + 128*i2 + i3, 1, 12)],
origins=OrderedSet([view, mul])
)
),)
args[8]: ()
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
The following code leads to an error. The error depends on the `torch.compile` mode I'm using.

If using `torch.compile(..., mode='default', ...)`, I get the error shown above.

Notably, the error goes away if I move the following line in `generate_score_mod` to `__init__` instead.

Relevant specs:
- torch version: 2.6.0.dev20240918