🐛 Describe the bug
When Triton matmul is enabled by setting config.triton.mm to either "autotune" or "triton", compilation crashes with a KeyError on "BLOCK_K", which appears to be a Triton kernel meta-parameter.
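For context, this is the only configuration change relative to the stock model (copied from the repro below):

from torch._inductor import config
config.triton.mm = "autotune"  # crashes; config.triton.mm = "triton" crashes the same way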
Error logs
~/project/sandbox$ python test_bert.py
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py:372: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled.Consider setting `torch.set_float32_matmul_precision('high')`
warnings.warn(
/usr/local/lib/python3.9/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: /usr/local/lib/python3.9/dist-packages/torchvision/image.so: undefined symbol: _ZN3c107WarningC1ENS_7variantIJNS0_11UserWarningENS0_18DeprecationWarningEEEERKNS_14SourceLocationERKSsb
warn(f"Failed to load image Python extension: {e}")
sequence length = 12
Traceback (most recent call last):
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 296, in call_function
out = lowerings[target](*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/lowering.py", line 222, in wrapped
return decomp_fn(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/lowering.py", line 842, in mm
return TensorBox.create(ir.MatrixMultiply.create(a, b))
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/ir.py", line 2803, in create
kernel = tuned_mm(
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 202, in tuned_mm
timing, _, _ = autotune._bench(runnable_kernel, *run_args, **run_kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 40, in _bench
return do_bench(kernel_call)
File "/usr/local/lib/python3.9/dist-packages/triton/testing.py", line 140, in do_bench
fn()
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 36, in kernel_call
kernel(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/matmul.py", line 134, in forward
return _matmul_out._call(a, b, out, allow_tf32)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/matmul.py", line 114, in _call
_kernel[grid](
File "/usr/local/lib/python3.9/dist-packages/triton/runtime/jit.py", line 106, in launcher
return self.run(*args, grid=grid, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/triton/runtime/autotuner.py", line 199, in run
kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/autotune.py", line 560, in <lambda>
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
KeyError: 'BLOCK_K'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 676, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/debug_utils.py", line 1032, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/__init__.py", line 1190, in __call__
return self.compile_fn(model_, inputs_)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 398, in compile_fx
return aot_autograd(
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/optimizations/training.py", line 78, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 2355, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 2052, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_tensor_args, aot_config)
File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 1307, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 1566, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 373, in fw_compiler
return inner_compile(
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/debug_utils.py", line 588, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/debug.py", line 223, in inner
return fn(*args, **kwargs)
File "/usr/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 139, in compile_fx_inner
graph.run(*example_inputs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 170, in run
return super().run(*args)
File "/usr/local/lib/python3.9/dist-packages/torch/fx/interpreter.py", line 136, in run
self.env[node] = self.run_node(node)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 369, in run_node
result = super().run_node(n)
File "/usr/local/lib/python3.9/dist-packages/torch/fx/interpreter.py", line 177, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 299, in call_function
raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: KeyError: 'BLOCK_K'
target: aten.mm.default
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.float32, size=(12, 768), stride=[768, 1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf5, i1 + 768 * i0)
return tmp0
,
ranges=(12, 768),
origins={view}
))
))
args[1]: TensorBox(
ReinterpretView(
StorageBox(
InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[768, 768], stride=[768, 1]))
),
FixedLayout('cuda', torch.float32, size=[768, 768], stride=[1, 768]),
no origins?
)
)
While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_encoder': "<class 'transformers.models.bert.modeling_bert.BertEncoder'>", 'self_encoder_layer_0': "<class 'transformers.models.bert.modeling_bert.BertLayer'>", 'self_encoder_layer_0_attention': "<class 'transformers.models.bert.modeling_bert.BertAttention'>", 'self_encoder_layer_0_attention_self': "<class 'transformers.models.bert.modeling_bert.BertSelfAttention'>", 'self_encoder_layer_0_attention_self_query': "<class 'torch.nn.modules.linear.Linear'>"}
File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 285, in forward
mixed_query_layer = self.query(hidden_states)
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 426, in forward
self_outputs = self.self(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 496, in forward
self_attention_outputs = self.attention(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 610, in forward
layer_outputs = layer_module(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 1021, in forward
encoder_outputs = self.encoder(
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/A/project/sandbox/test_bert.py", line 60, in <module>
print(f"gains={seq_len}:{measure_perf(t)}")
File "/home/A/project/sandbox/test_bert.py", line 40, in measure_perf
output = opt_model(**token)
File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1482, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 83, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 212, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 333, in catch_errors
return callback(frame, cache_size, hooks)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 480, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
return _compile(
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 400, in _compile
out_code = transform_code_object(code, transform)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 387, in transform
tracer.run()
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 1684, in run
super().run()
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
and self.step()
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
getattr(self, inst.opname)(inst)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 1750, in RETURN_VALUE
self.output.compile_subgraph(self)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 553, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 600, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 681, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: KeyError: 'BLOCK_K'
target: aten.mm.default
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.float32, size=(12, 768), stride=[768, 1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf5, i1 + 768 * i0)
return tmp0
,
ranges=(12, 768),
origins={view}
))
))
args[1]: TensorBox(
ReinterpretView(
StorageBox(
InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[768, 768], stride=[768, 1]))
),
FixedLayout('cuda', torch.float32, size=[768, 768], stride=[1, 768]),
no origins?
)
)
While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_encoder': "<class 'transformers.models.bert.modeling_bert.BertEncoder'>", 'self_encoder_layer_0': "<class 'transformers.models.bert.modeling_bert.BertLayer'>", 'self_encoder_layer_0_attention': "<class 'transformers.models.bert.modeling_bert.BertAttention'>", 'self_encoder_layer_0_attention_self': "<class 'transformers.models.bert.modeling_bert.BertSelfAttention'>", 'self_encoder_layer_0_attention_self_query': "<class 'torch.nn.modules.linear.Linear'>"}
File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 285, in forward
mixed_query_layer = self.query(hidden_states)
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 426, in forward
self_outputs = self.self(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 496, in forward
self_attention_outputs = self.attention(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 610, in forward
layer_outputs = layer_module(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 1021, in forward
encoder_outputs = self.encoder(
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
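Judging from the traceback, the crash comes from the EVEN_K heuristic that Inductor attaches to its Triton matmul kernel: the lambda indexes args["BLOCK_K"] and args["SPLIT_K"], but on the autotuner's benchmarking path the argument dict apparently contains only the runtime arguments, not the kernel meta-parameters. Below is a minimal sketch of that failure mode; the contents of the args dict are an assumption, and this is not the actual Inductor call site.

# Illustrative sketch only; the contents of `args` are assumed, not taken from Inductor.
# The heuristic quoted in the log (torch/_inductor/triton_ops/autotune.py, line 560):
even_k = lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0

# Hypothetical argument dict as it seems to be built when the kernel is benchmarked
# directly: runtime shapes are present, but meta-parameters like BLOCK_K are not.
args = {"M": 12, "N": 768, "K": 768}

try:
    even_k(args)
except KeyError as e:
    print(f"KeyError: {e}")  # prints: KeyError: 'BLOCK_K', matching the log above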
Minified repro
"""
bert
"""
import torch
import numpy as np
from transformers import BertTokenizer, BertModel
from torch._inductor import config
config.triton.mm = "autotune"
#config.triton.mm = "triton"
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
opt_model = torch.compile(model, backend="inductor")  # This is the only line of code that we changed
#opt_model = torch.compile(model, passes={"triton-autotune": True})
#opt_model = torch.compile(model, passes={"triton-mm": "triton"})
#opt_model = torch.compile(model, passes={'triton-mm': "triton", 'triton-bmm': True})  # this also fails with a different error message
def measure_perf(token, verbose=False, N=8):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    ref_runtime = list()
    print(f"sequence length = {token['input_ids'].shape[1]}")
    for i in range(N):
        start_event.record()
        output = model(**token)
        end_event.record()
        torch.cuda.synchronize()
        estimate_ms = start_event.elapsed_time(end_event)
        if verbose:
            print(f"model: estimated_ms={estimate_ms}")
        if i > 0:
            ref_runtime.append(estimate_ms)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    triton_runtime = list()
    for i in range(N):
        start_event.record()
        output = opt_model(**token)
        end_event.record()
        torch.cuda.synchronize()
        estimate_ms = start_event.elapsed_time(end_event)
        if verbose:
            print(f"opt_model: estimated_ms={estimate_ms}")
        if i > 0:
            triton_runtime.append(estimate_ms)

    a = np.mean(ref_runtime)
    b = np.mean(triton_runtime)
    gain = a / b
    print(f"model: estimated_ms={a}")
    print(f"triton model: estimated_ms={b}, gain={gain}")
    return gain
text = "Replace me by any text you'd like."
t = tokenizer(text, return_tensors='pt').to(device="cuda:0")
seq_len = t['input_ids'].shape[1]
print(f"gains={seq_len}:{measure_perf(t)}")
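For reference, here is a smaller script that may reproduce the same crash without transformers, assuming any aten.mm lowering hits the same tuned_mm path once config.triton.mm is set; the 12x768 and 768x768 shapes are simply copied from the failing BERT query projection above.

import torch
from torch._inductor import config

config.triton.mm = "autotune"  # "triton" appears to crash the same way

def f(a, b):
    return a @ b  # a 2D @ 2D product lowers to aten.mm.default

opt_f = torch.compile(f, backend="inductor")

a = torch.randn(12, 768, device="cuda")   # same shapes as the BERT matmul in the log
b = torch.randn(768, 768, device="cuda")
out = opt_f(a, b)
print(out.shape)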