Open jinsong-mao opened 5 months ago
Suggestion Description
Thanks for this great work.

There is a performance gap between AOT and JIT Triton for flash attention across most seqlen, n_heads, and head_dim configurations. We tried tuning the flash attention kernel and got some performance improvement at head_dim=128, but it is still slower than the JIT Triton kernel.

The two kernels appear to have different tuning spaces, and this is the main difference we found:
Triton kernel tuning space for AOTriton: https://github.com/ROCm/aotriton/blob/main/tritonsrc/attn_torch_function.py#L47
Triton kernel tuning space for JIT Triton: https://github.com/ROCm/triton/blob/triton-mlir/python/tutorials/06-fused-attention.py#L84-L92
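For context, the JIT tutorial expresses its tuning space as a list of `triton.Config` entries passed to `triton.autotune`. Below is a minimal sketch of that pattern; the block sizes, warp counts, and the kernel signature are illustrative assumptions for discussion, not values copied from either repository:

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        # Illustrative configs only: larger BLOCK_M/BLOCK_N tiles combined with
        # varied num_warps, roughly in the spirit of the JIT tutorial's space.
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64},  num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128}, num_stages=1, num_warps=8),
        triton.Config({'BLOCK_M': 64,  'BLOCK_N': 64},  num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64},  num_stages=1, num_warps=8),
    ],
    key=['seqlen_q', 'seqlen_k', 'HEAD_DIM'],  # re-tune when the problem shape changes
)
@triton.jit
def attn_fwd(Q, K, V, Out, seqlen_q, seqlen_k,
             HEAD_DIM: tl.constexpr,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Flash attention forward body elided; only the autotune wrapper is relevant here.
    pass
```

If the AOTriton build only pre-compiles a narrower set of these configurations, the autotuner has fewer candidates to pick from at a given (seqlen, n_heads, head_dim), which could account for part of the gap we see.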
Is there any other main difference that makes the JIT Triton flash attention kernel faster than the AOT Triton one?
Operating System
ubuntu 22
GPU
mi300
ROCm Component
rocBLAS