triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Low performance on RX 6800XT / Navi21 #3704

Open 2440020096 opened 5 months ago

2440020096 commented 5 months ago

When running the 03-matrix-multiply tutorial, Triton's performance is much lower than rocBLAS (throughput in TFLOPS):

        M       N        K    rocBLAS    Triton
0  1024.0  1024.0   1024.0  21.770480  3.301941
1  2048.0  2048.0   2048.0  25.513268  3.196135
2  4096.0  4096.0   4096.0  22.292006  3.373854
3  8192.0  8192.0   8192.0  23.191270  3.307897
4  9728.0  8192.0  65536.0  18.716197  3.197035
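The table reports throughput, not time. As a sketch, assuming the usual convention of counting 2·M·N·K floating-point operations per matrix multiply, the figures can be converted back to kernel runtimes and compared directly. The helpers `tflops` and `runtime_ms` are hypothetical and not part of Triton:

```python
# Hypothetical helpers, not part of Triton: convert between GEMM runtime and TFLOPS,
# assuming 2 * M * N * K floating-point operations per matrix multiply.

def tflops(m: int, n: int, k: int, ms: float) -> float:
    """Throughput in TFLOPS for an (M x K) @ (K x N) matmul that took `ms` milliseconds."""
    return 2 * m * n * k * 1e-12 / (ms * 1e-3)


def runtime_ms(m: int, n: int, k: int, tflops_value: float) -> float:
    """Inverse of tflops(): the runtime in milliseconds implied by a TFLOPS figure."""
    return 2 * m * n * k * 1e-12 / tflops_value * 1e3


# Rows copied from the benchmark table above: (M, N, K, rocBLAS, Triton).
rows = [
    (1024, 1024, 1024, 21.770480, 3.301941),
    (2048, 2048, 2048, 25.513268, 3.196135),
    (4096, 4096, 4096, 22.292006, 3.373854),
    (8192, 8192, 8192, 23.191270, 3.307897),
    (9728, 8192, 65536, 18.716197, 3.197035),
]

for m, n, k, rocblas, triton_tf in rows:
    # Print implied runtimes and how many times slower the Triton kernel is.
    print(
        f"{m}x{n}x{k}: rocBLAS {runtime_ms(m, n, k, rocblas):8.2f} ms, "
        f"Triton {runtime_ms(m, n, k, triton_tf):8.2f} ms, "
        f"slowdown {rocblas / triton_tf:.1f}x"
    )
```

Across all five shapes, the Triton kernel comes out roughly 6x to 8x slower than rocBLAS, and its throughput stays flat at about 3.2 to 3.4 TFLOPS regardless of problem size.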

Additionally, running the 06-fused-attention tutorial fails with this error:

Traceback (most recent call last):
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1142, in ast_to_ttir
    generator.visit(fn.parse())
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 292, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 362, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 287, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 414, in visit_Assign
    values = self.visit(node.value)
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 954, in visit_Call
    return fn(*args, **extra_kwargs, **kws)
  File "/home/k/.local/lib/python3.10/site-packages/triton/language/core.py", line 29, in wrapper
    return fn(*args, **kwargs)
  File "/home/k/.local/lib/python3.10/site-packages/triton/language/core.py", line 1132, in make_block_ptr
    return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
  File "/home/k/.local/lib/python3.10/site-packages/triton/language/semantic.py", line 1662, in make_block_ptr
    assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order"
AssertionError: Expected a permutation of (0, 1, ..., len(order)-1) in order

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/media/k/SSD-too/triton/06-fused-attention.py", line 664, in <module>
    bench_flash_attention.run(save_path=".", print_data=True)
  File "/home/k/.local/lib/python3.10/site-packages/triton/testing.py", line 338, in run
    self._run(bench, save_path, show_plots, print_data)
  File "/home/k/.local/lib/python3.10/site-packages/triton/testing.py", line 290, in _run
    ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
  File "/media/k/SSD-too/triton/06-fused-attention.py", line 644, in bench_flash_attention
    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
  File "/home/k/.local/lib/python3.10/site-packages/triton/testing.py", line 106, in do_bench
    fn()
  File "/media/k/SSD-too/triton/06-fused-attention.py", line 639, in <lambda>
    fn = lambda: attention(q, k, v, causal, sm_scale)
  File "/home/k/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 569, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/media/k/SSD-too/triton/06-fused-attention.py", line 474, in forward
    _attn_fwd[grid](
  File "<string>", line 74, in _attn_fwd
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 552, in compile
    next_module = compile_kernel(module)
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 427, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/home/k/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1151, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 35:14:        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=v_order,
              ^
AssertionError('Expected a permutation of (0, 1, ..., len(order)-1) in order')
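The failing assertion only checks that `order` is a permutation of `0..len(order)-1`. A plain-Python restatement of that check (the function name `is_valid_order` is hypothetical) shows that both values the tutorial can select for `v_order`, `(0, 1)` and `(1, 0)`, satisfy it on their own. This suggests, without confirming, that the problem lies in how this Triton build handles the conditional `tl.constexpr` tuple passed to `make_block_ptr`, not in the values themselves:

```python
# Plain-Python restatement of the check in triton/language/semantic.py (line 1662 in the
# traceback above): `order` must be a permutation of 0..len(order)-1.

def is_valid_order(order) -> bool:
    return sorted(order) == list(range(len(order)))


# Both values the tutorial can select for v_order pass the check.
for candidate in [(0, 1), (1, 0)]:
    print(candidate, is_valid_order(candidate))

# Examples that would legitimately trip the assertion.
for candidate in [(1, 1), (0, 2), (1,)]:
    print(candidate, is_valid_order(candidate))
```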

Hardware: RX 6800 XT / Navi21 / gfx1030
PyTorch version: 2.4.0.dev20240405+rocm6.0
Triton version: 3.0.0+0a22a91d04, installed via the nightly PyTorch ROCm 6.0 build
OS: Ubuntu 22.04.3 LTS

Paging @sunway513

sunway513 commented 5 months ago

Thanks @2440020096 for reporting the issue. The team has just completed Navi31 support, and with this feature request we'll start looking into Navi21 support. We'll post an update when we have a plan to share. cc @joviliast @zhanglx13

zhanglx13 commented 5 months ago

Thanks for reporting. I can't assign this to Illia (due to a GitHub ID issue), so I've assigned it to myself. I'll update here when we have some results.

supernovae commented 4 months ago

> Thanks @2440020096 for reporting the issue. The team has just completed the Navi31 support, and with this feature request we'll start to look into Navi21 support. We'll update when we have plan to share. cc @joviliast @zhanglx13

It's fantastic news that Navi31 support is complete. Is there any documentation or are there release notes about it? I'd love to read more. Finding information on RDNA3 support, and on how the remaining components such as memory-efficient flash attention will fall into place for RDNA3, would be awesome.