ROCm / aotriton

Ahead of Time (AOT) Triton Math Library

[Issues]: The Gap between AOT and JIT Triton on Flash Attention kernel #34

Open jinsong-mao opened 5 months ago

jinsong-mao commented 5 months ago

Suggestion Description

Thanks for this great work.

There is a performance gap between AOT and JIT Triton for flash attention across most seqlen, n_heads, and head_dim configurations. We tried to tune the flash attention kernel and got some performance improvement at head_dim=128; however, it is still slower than the JIT Triton kernel.
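For reference, here is a minimal benchmark sketch of the comparison we are making. It assumes the ROCm PyTorch build dispatches `torch.nn.functional.scaled_dot_product_attention` to the AOT (aotriton) kernels; the JIT path import and call signature below are hypothetical placeholders, not the exact tutorial API.

```python
import torch
import triton


def bench_ms(fn, *args, **kwargs):
    # triton.testing.do_bench reports the runtime of the callable in milliseconds.
    return triton.testing.do_bench(lambda: fn(*args, **kwargs))


# Illustrative shape: batch=4, n_heads=32, seqlen=4096, head_dim=128 (fp16 on MI300).
B, H, N, D = 4, 32, 4096, 128
q, k, v = (torch.randn(B, H, N, D, device="cuda", dtype=torch.float16) for _ in range(3))

# AOT path: PyTorch SDPA, which the ROCm build routes to the aotriton kernels.
aot_ms = bench_ms(torch.nn.functional.scaled_dot_product_attention, q, k, v)
print(f"AOT (SDPA via aotriton): {aot_ms:.3f} ms")

# JIT path: call the autotuned Triton kernel from 06-fused-attention.py directly,
# e.g. (hypothetical import path and signature):
#   from fused_attention import attention
#   jit_ms = bench_ms(attention, q, k, v, False, 1.0 / D ** 0.5)
#   print(f"JIT (tutorial kernel)  : {jit_ms:.3f} ms")
```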

It looks like their Triton kernel tuning spaces differ, and this is the main difference we found:

- Tuning space for AOT Triton: https://github.com/ROCm/aotriton/blob/main/tritonsrc/attn_torch_function.py#L47
- Tuning space for JIT Triton: https://github.com/ROCm/triton/blob/triton-mlir/python/tutorials/06-fused-attention.py#L84-L92
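For illustration, this is roughly what a wider autotune space looks like when declared with `@triton.autotune`; the block sizes, `waves_per_eu`, `pre_load_v`, and key names below are placeholders, not the exact values used by either repository.

```python
import triton

# Illustrative tuning space; the real configs in attn_torch_function.py and
# 06-fused-attention.py differ and should be compared side by side.
FWD_CONFIGS = [
    triton.Config(
        {"BLOCK_M": bm, "BLOCK_N": bn, "waves_per_eu": wpe, "pre_load_v": plv},
        num_stages=1,
        num_warps=nw,
    )
    for bm in (64, 128)
    for bn in (32, 64, 128)
    for wpe in (1, 2, 3)
    for plv in (False, True)
    for nw in (4, 8)
]

# The decorator re-tunes whenever any of the keyed shape parameters changes:
#
# @triton.autotune(configs=FWD_CONFIGS, key=["seqlen_q", "seqlen_k", "BLOCK_DMODEL"])
# @triton.jit
# def attn_fwd(...):
#     ...
```

Since the AOT library has to compile every config ahead of time, its tuning space is presumably kept much smaller than what a JIT autotuner can explore at runtime, which may account for part of the gap on shapes outside that set.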

Is there any other main difference that makes JIT Triton faster than the AOT Triton FA kernel?

Operating System

Ubuntu 22

GPU

MI300

ROCm Component

rocBLAS