triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

[BUG Report] Triton error when using tutorials #1994

Open cccc0der opened 1 year ago

cccc0der commented 1 year ago

Hi, I'm new to Triton and am doing some pretraining work.

I tested the Triton tutorials: 01-vector-add.py, 02-fused-softmax, 04-low-memory-dropout, and 05-layer-norm work fine, but errors occur when I run 03-matrix-multiplication and 06-fused-attention.

env list

Python: 3.8.3
PyTorch:  1.13.1+cu117
CUDA: 11.7
GPU: T4
Triton Tutorials: 2.0.0
Triton Version: 2.0.0.dev20221202

I have also tested with torch 1.12 / CUDA 11.6 and it still does not work. Is this a GPU incompatibility problem?

03-matmul error (I also hit the same Triton error with flash_attn_triton):

Traceback (most recent call last):
  File "<string>", line 21, in matmul_kernel
KeyError: ('2-.-0-.-0-7d1eb0d2fed8ff2032dccb99c2cc311a-2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-42648570729a4835b21c1c18cebedbfe-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-b9ae7213d41541f67843018d049e1f90-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, torch.float16, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 256, 64, 8, None), (True, True, True, (True, False), (True, False), (True, False), (True, False), (False, True), (True, False), (False, True), (True, False), (False, True)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "03-matrix-multiplication.py", line 290, in <module>
    triton_output = matmul(a, b, activation=None)
  File "03-matrix-multiplication.py", line 270, in matmul
    matmul_kernel[grid](
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 86, in run
    return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
  File "<string>", line 43, in matmul_kernel
RuntimeError: Triton Error [CUDA]: invalid argument

06-fused-attention error:

Traceback (most recent call last):
  File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0-7d1eb0d2fed8ff2032dccb99c2cc311a-2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-42648570729a4835b21c1c18cebedbfe-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-b9ae7213d41541f67843018d049e1f90-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, torch.float16, torch.float16, 'fp32', torch.float32, torch.float32, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 64, 128), (True, True, True, (False,), True, True, True, (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (False, False), (True, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 838, in make_triton_ir
    generator.visit(fn.parse())
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/usr/local/python3/lib/python3.8/ast.py", line 363, in visit
    return visitor(node)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 260, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/local/python3/lib/python3.8/ast.py", line 371, in generic_visit
    self.visit(item)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/usr/local/python3/lib/python3.8/ast.py", line 363, in visit
    return visitor(node)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 320, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 254, in visit_compound_statement
    self.last_ret = self.visit(stmt)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/usr/local/python3/lib/python3.8/ast.py", line 363, in visit
    return visitor(node)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 648, in visit_For
    self.visit_compound_statement(node.body)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 254, in visit_compound_statement
    self.last_ret = self.visit(stmt)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/usr/local/python3/lib/python3.8/ast.py", line 363, in visit
    return visitor(node)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 395, in visit_AugAssign
    self.visit(assign)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/usr/local/python3/lib/python3.8/ast.py", line 363, in visit
    return visitor(node)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 367, in visit_Assign
    values = self.visit(node.value)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/usr/local/python3/lib/python3.8/ast.py", line 363, in visit
    return visitor(node)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 451, in visit_BinOp
    return getattr(lhs, fn)(rhs, _builder=self.builder)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/language/core.py", line 46, in wrapper
    return fn(*args, **kwargs)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/language/core.py", line 456, in __mul__
    return semantic.mul(self, other, _builder)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/language/semantic.py", line 163, in mul
    input, other = binary_op_type_checking_impl(input, other, builder)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/language/semantic.py", line 107, in binary_op_type_checking_impl
    lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/language/semantic.py", line 524, in broadcast_impl_value
    raise ValueError("Cannot make_shape_compatible: blocks must have the same rank")
ValueError: Cannot make_shape_compatible: blocks must have the same rank

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "06-fused-attention.py", line 358, in <module>
    bench_flash_attention.run(save_path='.', print_data=True)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/testing.py", line 314, in run
    self._run(bench, save_path, show_plots, print_data)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/testing.py", line 269, in _run
    ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
  File "06-fused-attention.py", line 341, in bench_flash_attention
    ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/testing.py", line 140, in do_bench
    fn()
  File "06-fused-attention.py", line 336, in <lambda>
    fn = lambda: attention(q, k, v, sm_scale)
  File "06-fused-attention.py", line 213, in forward
    _fwd_kernel[grid](
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "<string>", line 41, in _fwd_kernel
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 1256, in compile
    asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 892, in _compile
    module, _ = make_triton_ir(fn, signature, specialization, constants)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 843, in make_triton_ir
    raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 49:13:
def _fwd_kernel(
    Q, K, V, sm_scale,
    L, M,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # -- compute qk ----
        k = tl.load(k_ptrs)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # compute new m
        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
        # correct old l
        l_prev *= tl.exp(m_prev - m_curr)
        # attention weights
        p = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p, 1) + l_prev
        # rescale operands of matmuls
        l_rcp = 1. / l_curr
        p *= l_rcp
             ^
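The compile error points at `p *= l_rcp`, where `p` is a 2-D `[BLOCK_M, BLOCK_N]` block and `l_rcp` is a 1-D `[BLOCK_M]` block: exactly the rank mismatch that `broadcast_impl_value` rejects. The kernel already handles the same situation a few lines earlier with `m_curr[:, None]`. The sketch below uses NumPy as a stand-in (not Triton) to show why the explicit `[:, None]` matters when the intent is to normalize each row; whether `p *= l_rcp[:, None]` is the right fix for this particular Triton build is an assumption the traceback does not confirm.

```python
import numpy as np

# NumPy stand-in for the softmax-normalization step in _fwd_kernel (illustrative only).
BLOCK_M, BLOCK_N = 4, 4
rng = np.random.default_rng(0)
qk = rng.standard_normal((BLOCK_M, BLOCK_N)).astype(np.float32)

m_curr = qk.max(axis=1)              # shape (BLOCK_M,): per-row max
p = np.exp(qk - m_curr[:, None])     # same [:, None] pattern the kernel uses
l_curr = p.sum(axis=1)               # shape (BLOCK_M,): per-row normalizer
l_rcp = 1.0 / l_curr

rows = p * l_rcp[:, None]  # row i scaled by l_rcp[i]
cols = p * l_rcp           # NumPy prepends an axis: column j scaled by l_rcp[j]

print(np.allclose(rows.sum(axis=1), 1.0))  # True: each row is a softmax
print(cols.sum(axis=1))                    # row sums are generally not 1
```

NumPy silently accepts the rank-1 operand by prepending an axis, which scales the wrong dimension when the block is square; the Triton version in this traceback instead refuses the mismatch outright.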
cccc0der commented 1 year ago

I also reinstalled Triton 2.0.0, but the errors persist:

03-matmul

Traceback (most recent call last):
  File "<string>", line 21, in matmul_kernel
KeyError: ('2-.-0-.-0-7d1eb0d2fed8ff2032dccb99c2cc311a-d6252949da17ceb5f3a278a70250af13-1af5134066c618146d2cd009138944a0-14de7de5c4da5794c8ca14e7e41a122d-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-b9ae7213d41541f67843018d049e1f90-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float16, torch.float16, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 256, 64, 8, None), (True, True, True, (True, False), (True, False), (True, False), (True, False), (False, True), (True, False), (False, True), (True, False), (False, True)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "03-matrix-multiplication.py", line 290, in <module>
    triton_output = matmul(a, b, activation=None)
  File "03-matrix-multiplication.py", line 270, in matmul
    matmul_kernel[grid](
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 90, in run
    return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
  File "<string>", line 43, in matmul_kernel
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 1678, in __getattribute__
    self._init_handles()
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 1670, in _init_handles
    raise OutOfResources(self.shared, max_shared, "shared memory")
triton.compiler.OutOfResources: out of resource: shared memory, Required: 147456, Hardware limit: 65536. Reducing block sizes or `num_stages` may help.
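The numbers in this error can be reproduced with simple arithmetic. Assuming each software-pipelining stage buffers one fp16 `BLOCK_SIZE_M x BLOCK_SIZE_K` tile of A and one `BLOCK_SIZE_K x BLOCK_SIZE_N` tile of B, and assuming `num_stages=3` for the (128, 256, 64) configuration in the KeyError tuple, the estimate is (128*64 + 64*256) * 2 bytes * 3 stages = 147456 bytes, which matches "Required" exactly and is far above the 65536-byte limit. The helper below is a hypothetical estimator, not Triton's allocator:

```python
# Hypothetical estimator: assumes each pipeline stage holds one fp16 A tile
# (BLOCK_M x BLOCK_K) and one fp16 B tile (BLOCK_K x BLOCK_N). Triton's real
# shared-memory allocation may include extra buffers; this is only a rough check.
def estimate_matmul_smem(block_m, block_n, block_k, num_stages, elem_bytes=2):
    a_tile = block_m * block_k * elem_bytes
    b_tile = block_k * block_n * elem_bytes
    return (a_tile + b_tile) * num_stages


T4_SMEM_LIMIT = 65536  # "Hardware limit" from the OutOfResources message

candidates = [
    # (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages)
    (128, 256, 64, 3),  # assumed to be the failing config
    (128, 256, 64, 2),
    (64, 64, 32, 4),
    (32, 64, 32, 5),
]

for bm, bn, bk, stages in candidates:
    need = estimate_matmul_smem(bm, bn, bk, stages)
    verdict = "fits" if need <= T4_SMEM_LIMIT else "too big"
    print(f"{bm}x{bn}x{bk}, stages={stages}: {need} bytes ({verdict})")
```

Under this estimate, dropping the large tile to two stages (98304 bytes) is still too big, while 64x64x32 tiles with four stages need only 32768 bytes, which is why reducing block sizes is the first thing to try on a 64 KB part.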

06-fused-attention

error: 'tt.reduce' op inferred type(s) 'tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.mma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [4, 1]}>}>>' are incompatible with return type(s) of operation 'tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.mma<{versionMajor = 1, versionMinor = 2, warpsPerCTA = [2, 2]}>}>>'
Traceback (most recent call last):
  File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0-7d1eb0d2fed8ff2032dccb99c2cc311a-d6252949da17ceb5f3a278a70250af13-1af5134066c618146d2cd009138944a0-14de7de5c4da5794c8ca14e7e41a122d-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-b9ae7213d41541f67843018d049e1f90-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float16, torch.float16, torch.float16, 'fp32', torch.float32, torch.float32, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 64, 128), (True, True, True, (False,), True, True, True, (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (False, False), (True, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "06-fused-attention.py", line 358, in <module>
    bench_flash_attention.run(save_path='.', print_data=True)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/testing.py", line 317, in run
    self._run(bench, save_path, show_plots, print_data)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/testing.py", line 272, in _run
    ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
  File "06-fused-attention.py", line 341, in bench_flash_attention
    ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/testing.py", line 143, in do_bench
    fn()
  File "06-fused-attention.py", line 336, in <lambda>
    fn = lambda: attention(q, k, v, sm_scale)
  File "06-fused-attention.py", line 213, in forward
    _fwd_kernel[grid](
  File "<string>", line 41, in _fwd_kernel
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 1620, in compile
    next_module = compile(module)
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 1551, in <lambda>
    lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
  File "/home/tysearch/.local/lib/python3.8/site-packages/triton/compiler.py", line 992, in ttir_to_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed
cccc0der commented 1 year ago

I suppose the cause is that the T4 (Turing) does not support num_stages?

num_stages – the number of stages that the compiler should use when software-pipelining loops. Mostly useful for matrix multiplication workloads on SM80+ GPUs.
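The shared-memory error in the 2.0.0 run suggests `num_stages` is being applied rather than ignored, since the footprint of a pipelined matmul grows linearly with the stage count. Assuming each stage buffers one fp16 A tile and one fp16 B tile (a simplification, not Triton's exact accounting), a hypothetical helper can report the largest stage count a given tile shape could use within the 65536-byte limit:

```python
# Hypothetical helper: assumes per-stage shared memory = one fp16 A tile
# (block_m x block_k) plus one fp16 B tile (block_k x block_n).
def max_stages_that_fit(block_m, block_n, block_k, smem_limit, elem_bytes=2):
    per_stage = (block_m * block_k + block_k * block_n) * elem_bytes
    return smem_limit // per_stage  # 0 means even a single stage does not fit


T4_SMEM_LIMIT = 65536  # limit reported by the OutOfResources error

print(max_stages_that_fit(128, 256, 64, T4_SMEM_LIMIT))  # 1
print(max_stages_that_fit(64, 64, 32, T4_SMEM_LIMIT))    # 8
```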
cccc0der commented 1 year ago

It seems a kernel can only take one meta-parameter: if I declare more than one tl.constexpr parameter, a KeyError is raised.

@triton.jit
def matmul_kernel(
    ...
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):