pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License
1.01k stars 124 forks source link

Crash when Triton matmul is enabled #2015

Closed stephen-youn closed 1 year ago

stephen-youn commented 1 year ago

šŸ› Describe the bug

When Triton matmul is enabled by setting `config.triton.mm` to either "autotune" or "triton", compilation crashes with `KeyError: 'BLOCK_K'`; `BLOCK_K` appears to be a parameter of the Triton matmul kernel.

Error logs

~/project/sandbox$ python test_bert.py
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py:372: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled.Consider setting `torch.set_float32_matmul_precision('high')`
  warnings.warn(
/usr/local/lib/python3.9/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: /usr/local/lib/python3.9/dist-packages/torchvision/image.so: undefined symbol: _ZN3c107WarningC1ENS_7variantIJNS0_11UserWarningENS0_18DeprecationWarningEEEERKNS_14SourceLocationERKSsb
  warn(f"Failed to load image Python extension: {e}")
sequnece length = 12
Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 296, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/lowering.py", line 222, in wrapped
    return decomp_fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/lowering.py", line 842, in mm
    return TensorBox.create(ir.MatrixMultiply.create(a, b))
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/ir.py", line 2803, in create
    kernel = tuned_mm(
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 202, in tuned_mm
    timing, _, _ = autotune._bench(runnable_kernel, *run_args, **run_kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 40, in _bench
    return do_bench(kernel_call)
  File "/usr/local/lib/python3.9/dist-packages/triton/testing.py", line 140, in do_bench
    fn()
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 36, in kernel_call
    kernel(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/matmul.py", line 134, in forward
    return _matmul_out._call(a, b, out, allow_tf32)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/matmul.py", line 114, in _call
    _kernel[grid](
  File "/usr/local/lib/python3.9/dist-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/triton/runtime/autotuner.py", line 199, in run
    kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/autotune.py", line 560, in <lambda>
    "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
KeyError: 'BLOCK_K'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 676, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/debug_utils.py", line 1032, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/__init__.py", line 1190, in __call__
    return self.compile_fn(model_, inputs_)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 398, in compile_fx
    return aot_autograd(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/optimizations/training.py", line 78, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 2355, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 2052, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_tensor_args, aot_config)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 1307, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 1566, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 373, in fw_compiler
    return inner_compile(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/debug_utils.py", line 588, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/debug.py", line 223, in inner
    return fn(*args, **kwargs)
  File "/usr/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 139, in compile_fx_inner
    graph.run(*example_inputs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 170, in run
    return super().run(*args)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 369, in run_node
    result = super().run_node(n)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 299, in call_function
    raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: KeyError: 'BLOCK_K'
  target: aten.mm.default
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.float32, size=(12, 768), stride=[768, 1]), data=Pointwise(
      'cuda',
      torch.float32,
      tmp0 = load(buf5, i1 + 768 * i0)
      return tmp0
      ,
      ranges=(12, 768),
      origins={view}
    ))
  ))
  args[1]: TensorBox(
    ReinterpretView(
      StorageBox(
        InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[768, 768], stride=[768, 1]))
      ),
      FixedLayout('cuda', torch.float32, size=[768, 768], stride=[1, 768]),
      no origins?
    )
  )

While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_encoder': "<class 'transformers.models.bert.modeling_bert.BertEncoder'>", 'self_encoder_layer_0': "<class 'transformers.models.bert.modeling_bert.BertLayer'>", 'self_encoder_layer_0_attention': "<class 'transformers.models.bert.modeling_bert.BertAttention'>", 'self_encoder_layer_0_attention_self': "<class 'transformers.models.bert.modeling_bert.BertSelfAttention'>", 'self_encoder_layer_0_attention_self_query': "<class 'torch.nn.modules.linear.Linear'>"}
  File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 285, in forward
    mixed_query_layer = self.query(hidden_states)
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 426, in forward
    self_outputs = self.self(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 496, in forward
    self_attention_outputs = self.attention(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 610, in forward
    layer_outputs = layer_module(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 1021, in forward
    encoder_outputs = self.encoder(

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/A/project/sandbox/test_bert.py", line 60, in <module>
    print(f"gains={seq_len}:{measure_perf(t)}")
  File "/home/A/project/sandbox/test_bert.py", line 40, in measure_perf
    output = opt_model(**token)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 83, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 212, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 333, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 480, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 400, in _compile
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 387, in transform
    tracer.run()
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 1684, in run
    super().run()
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
    and self.step()
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 1750, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 553, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 600, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 681, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: KeyError: 'BLOCK_K'
  target: aten.mm.default
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.float32, size=(12, 768), stride=[768, 1]), data=Pointwise(
      'cuda',
      torch.float32,
      tmp0 = load(buf5, i1 + 768 * i0)
      return tmp0
      ,
      ranges=(12, 768),
      origins={view}
    ))
  ))
  args[1]: TensorBox(
    ReinterpretView(
      StorageBox(
        InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[768, 768], stride=[768, 1]))
      ),
      FixedLayout('cuda', torch.float32, size=[768, 768], stride=[1, 768]),
      no origins?
    )
  )

While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_encoder': "<class 'transformers.models.bert.modeling_bert.BertEncoder'>", 'self_encoder_layer_0': "<class 'transformers.models.bert.modeling_bert.BertLayer'>", 'self_encoder_layer_0_attention': "<class 'transformers.models.bert.modeling_bert.BertAttention'>", 'self_encoder_layer_0_attention_self': "<class 'transformers.models.bert.modeling_bert.BertSelfAttention'>", 'self_encoder_layer_0_attention_self_query': "<class 'torch.nn.modules.linear.Linear'>"}
  File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 285, in forward
    mixed_query_layer = self.query(hidden_states)
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 426, in forward
    self_outputs = self.self(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 496, in forward
    self_attention_outputs = self.attention(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 610, in forward
    layer_outputs = layer_module(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 1021, in forward
    encoder_outputs = self.encoder(

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

Minified repro

"""
bert
"""

import torch
import numpy as np
from transformers import BertTokenizer, BertModel
from torch._inductor import config
# Route inductor's matmul lowering through Triton kernels. Per the report
# above, both "autotune" and "triton" are reported to crash with
# KeyError: 'BLOCK_K'.
config.triton.mm = "autotune"
#config.triton.mm = "triton"

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
# Compiled copy of `model`; measure_perf() times both and reports the ratio.
# This statement must live on its own line: appended after a `#` on the line
# above it is a comment, `opt_model` is never defined, and measure_perf()
# fails with NameError instead of reproducing the crash.
opt_model = torch.compile(model, backend="inductor")  # This is the only line of code that we changed
# Other compile invocations tried while reproducing the issue:
#opt_model = torch.compile(model, passes={"triton-autotune":True})
#opt_model = torch.compile(model, passes={"triton-mm":"triton"})
#opt_model = torch.compile(model, passes={'triton-mm': "triton", 'triton-bmm': True}) # this also fails with a different error message

def _time_forward(fn, token, n, verbose, label):
    """Call ``fn(**token)`` ``n`` times and return the CUDA-event timings.

    Iteration 0 is treated as warm-up (for a ``torch.compile``'d model it is
    the call that triggers compilation) and is excluded from the result.

    Args:
        fn: callable invoked as ``fn(**token)``.
        token: mapping of keyword arguments forwarded to ``fn``.
        n: total number of iterations, including the warm-up one.
        verbose: if true, print every iteration's timing.
        label: prefix of the per-iteration verbose message.

    Returns:
        list of elapsed milliseconds for iterations 1..n-1.
    """
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    timings = []
    for i in range(n):
        start_event.record()
        fn(**token)
        end_event.record()
        # Wait for the GPU so both events have completed before reading them.
        torch.cuda.synchronize()
        estimate_ms = start_event.elapsed_time(end_event)
        if verbose:
            print(f"{label}: estimated_ms={estimate_ms}")
        if i > 0:
            timings.append(estimate_ms)
    return timings


def measure_perf(token, verbose=False, N=8, *, ref_model=None, test_model=None):
    """Compare mean latency of the eager and compiled models on ``token``.

    Each model is run ``N`` times; the first run of each is discarded as
    warm-up and the remaining ``N - 1`` timings are averaged.

    Args:
        token: tokenizer output; must contain ``'input_ids'`` whose dim 1 is
            printed as the sequence length, and is passed as ``**token``.
        verbose: if true, print every individual timing.
        N: iterations per model, including the warm-up run. Must be >= 2.
        ref_model: baseline callable; defaults to the module-level ``model``.
        test_model: callable under test; defaults to the module-level
            ``opt_model``.

    Returns:
        mean baseline time divided by mean test time (> 1 means faster).

    Raises:
        ValueError: if ``N < 2`` (no timings would remain after warm-up).
    """
    if N < 2:
        raise ValueError(f"N must be >= 2 (iteration 0 is warm-up), got {N}")
    # Resolve the module globals lazily so callers passing both models do not
    # need them to exist.
    ref_fn = model if ref_model is None else ref_model
    test_fn = opt_model if test_model is None else test_model

    print(f"sequnece length = {token['input_ids'].shape[1]}")
    ref_runtime = _time_forward(ref_fn, token, N, verbose, "model")
    triton_runtime = _time_forward(test_fn, token, N, verbose, "opt_model")

    a = np.mean(ref_runtime)
    b = np.mean(triton_runtime)
    gain = a/b
    print(f"model: estimated_ms={a}")
    print(f"triton model: estimated_ms={b}, gain={gain}")
    return gain

# Tokenize one short sentence, move it to the GPU, then report the
# eager-vs-compiled speedup together with its sequence length.
text = "Replace me by any text you'd like."
encoded = tokenizer(text, return_tensors='pt')
t = encoded.to(device="cuda:0")

seq_len = t["input_ids"].shape[1]
gain = measure_perf(t)
print(f"gains={seq_len}:{gain}")
soumith commented 1 year ago

This will be fixed once https://github.com/pytorch/pytorch/pull/90738 re-lands.