I'm not aware that Triton 3.0 has been released.
@jataylo Could you chime in about this combination?
@zhanglx13 If I install a previous version of PyTorch (e.g., 2.1.2) that comes with Triton 2, then do I have to install this repo to have flash attention working? Thank you in advance for your answer.
Hey @netw0rkf10w @zhanglx13
The 3.0 version number is there so that we keep the same Triton version as the regular OpenAI/NVIDIA Triton wheel produced for PyTorch, but it is actually built from an earlier commit of the triton-mlir branch. We may want to rethink this versioning, or consider bumping the version number in our branch, so this is less confusing.
On the flash attention side, the question is whether this branch https://github.com/ROCm/triton/tree/pytorch_nightly/23_11_2023 supports it; it was cut off from the triton-mlir branch on Nov 3rd @zhanglx13.
@jataylo Thanks for the information.
FYI, I've tested with PyTorch nightly (pytorch-triton-rocm==3.0.0+dafe145982) and it didn't work:
Traceback (most recent call last):
File "/home/user/code/benchmark_flash_attention.py", line 174, in <module>
ex()
File "/home/user/code/benchmark_flash_attention.py", line 106, in ex
r4_fp16 = flash_attention_triton(q.half(), k.half(), v.half())
File "/home/user/code/benchmark_flash_attention.py", line 70, in flash_attention_triton
tri_out, _ = attention(q, k, v, None, metadata)
File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 572, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/user/code/triton/python/perf-kernels/flash-attention.py", line 801, in forward
attn_fwd[grid](
File "<string>", line 74, in attn_fwd
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 552, in compile
next_module = compile_kernel(module)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 427, in <lambda>
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1142, in ast_to_ttir
generator.visit(fn.parse())
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 292, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 362, in visit_FunctionDef
self.visit_compound_statement(node.body)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 287, in visit_compound_statement
ret_type = self.visit(stmt)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 414, in visit_Assign
values = self.visit(node.value)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 621, in visit_IfExp
raise UnsupportedLanguageConstruct(
triton.compiler.errors.UnsupportedLanguageConstruct: at 45:14: # small for all start_m so for those we return early.
        if start_m * BLOCK_M > seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = max_seqlens_q
        seqlen_k = max_seqlens_k
    off_h_k = off_h_q % hk if is_mqa else off_h_q
@netw0rkf10w Can you try installing Triton from this repo, following these steps:
git clone https://github.com/ROCm/triton.git
cd triton/python
pip install -e .
It should update Triton to top-of-tree (ToT) and work with PyTorch 2.1, 2.2, and 2.3. Note that if you want the AMD fp8 data type support, it has to be PyTorch 2.3.
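If it helps, here is an illustrative sanity check (not from the thread) to confirm which Triton build Python actually resolves after the source install, and that the ROCm PyTorch build still imports cleanly:

import triton
import torch

# The version string should reflect the source build, not the old wheel.
print(triton.__version__)
print(torch.__version__)
# torch.version.hip is non-None on ROCm builds of PyTorch.
print(torch.version.hip)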
@zhanglx13 I installed from this repo and now I got the following errors when running the code:
Traceback (most recent call last):
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1222, in ast_to_ttir
generator.visit(fn.parse())
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 299, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 372, in visit_FunctionDef
self.visit_compound_statement(node.body)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 294, in visit_compound_statement
ret_type = self.visit(stmt)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 641, in visit_If
self.visit_compound_statement(node.body)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 294, in visit_compound_statement
ret_type = self.visit(stmt)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 628, in visit_If
self.visit_if_scf(cond, node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 597, in visit_if_scf
self.visit_then_else_blocks(node, liveins, then_block, else_block)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 501, in visit_then_else_blocks
self.visit_compound_statement(node.body)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 294, in visit_compound_statement
ret_type = self.visit(stmt)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 424, in visit_Assign
values = self.visit(node.value)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1020, in visit_Call
return self.call_JitFunction(fn, args, kws)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 988, in call_JitFunction
generator.visit(fn.parse())
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 299, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 372, in visit_FunctionDef
self.visit_compound_statement(node.body)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 294, in visit_compound_statement
ret_type = self.visit(stmt)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 884, in visit_For
self.visit_compound_statement(node.body)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 294, in visit_compound_statement
ret_type = self.visit(stmt)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 447, in visit_AugAssign
self.visit(assign)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 424, in visit_Assign
values = self.visit(node.value)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 476, in visit_BinOp
rhs = self.visit(node.right)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
ret = super().visit(node)
File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1026, in visit_Call
return fn(*args, **extra_kwargs, **kws)
File "/home/user/.local/lib/python3.10/site-packages/triton/language/core.py", line 27, in wrapper
return fn(*args, **kwargs)
File "/home/user/.local/lib/python3.10/site-packages/triton/language/core.py", line 1059, in dot
return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder)
File "/home/user/.local/lib/python3.10/site-packages/triton/language/semantic.py", line 1271, in dot
and rhs.shape[1].value >= 4, \
AssertionError: All values in both first input shape ([constexpr[256], constexpr[8]]) and second input shape ([constexpr[8], constexpr[64]]) must be >= 4!
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/user/code/benchmark_flash_attention.py", line 173, in <module>
ex()
File "/home/user/code/benchmark_flash_attention.py", line 106, in ex
r4_fp16 = flash_attention_triton(q.half(), k.half(), v.half())
File "/home/user/code/benchmark_flash_attention.py", line 70, in flash_attention_triton
tri_out, _ = attention(q, k, v, None, metadata)
File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 572, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/user/code/torching/src/torching/flash_attention/flash_attention_src.py", line 801, in forward
attn_fwd[grid](
File "/home/user/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 555, in run
self.cache[device][key] = compile(
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 591, in compile
next_module = compile_kernel(module)
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 455, in <lambda>
ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1231, in ast_to_ttir
raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 145:16: if seqlen_k >= BLOCK_N:
            acc, l_i, m_i = _attn_fwd_inner(
                acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                start_m, seqlen_aligned,
                dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,
                BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                4 - STAGE, offs_m, offs_n,
                PRE_LOAD_V,
                False, seqlen_aligned,
                bias_ptr,
                ENABLE_DROPOUT,
                RETURN_ENCODED_SOFTMAX
                ^
AssertionError('All values in both first input shape ([constexpr[256], constexpr[8]]) and second input shape ([constexpr[8], constexpr[64]]) must be >= 4!')
It seems the tile size is too small: it's doing (256x8) x (8x64), but lhs.shape[1] must be >= 16. It's hard to say if this is a bug. Can you post the script you are running?
Oh, increasing D_HEAD from 8 to 16 works! I wonder why the head dimension has to be >= 16, because there's no such constraint in the original CUDA implementation...
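For what it's worth, here is a minimal standalone sketch (mine, not from the thread) that isolates the tl.dot shape check; the exact minimum (4 vs. 16) appears to depend on the backend and Triton version rather than on the algorithm:

import torch
import triton
import triton.language as tl

@triton.jit
def tiny_dot_kernel(a_ptr, b_ptr, c_ptr,
                    M: tl.constexpr, K: tl.constexpr, N: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_k = tl.arange(0, K)
    offs_n = tl.arange(0, N)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    c = tl.dot(a, b)  # the compile-time shape check fires here
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

M, K, N = 32, 16, 32  # K = 8 reproduces the AssertionError above
a = torch.randn(M, K, device='cuda', dtype=torch.float16)
b = torch.randn(K, N, device='cuda', dtype=torch.float16)
c = torch.empty(M, N, device='cuda', dtype=torch.float32)
tiny_dot_kernel[(1,)](a, b, c, M=M, K=K, N=N)
print((c - (a @ b).float()).abs().max())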
So apparently D_HEAD should be high, because otherwise it's even slower than the naive implementation. Below are some benchmark results for batch_size = 16, sequence_length = 256, num_heads in [8, 16, 32], and dims = [64, 128]:
[----------------------------- FlashAttention -----------------------------]
| sdp_attention | flash | flash-triton
1 threads: -----------------------------------------------------------------
[torch.float16, 8, 64] | 274.9 | 142.5 | 317.6
[torch.float16, 8, 128] | 354.2 | 245.8 | 365.2
[torch.float16, 16, 64] | 412.4 | 200.2 | 355.3
[torch.float16, 16, 128] | 688.1 | 380.9 | 434.0
[torch.float16, 32, 64] | 790.3 | 375.4 | 243.9
[torch.float16, 32, 128] | 1356.3 | 743.4 | 325.8
[torch.bfloat16, 8, 64] | 350.7 | 117.2 | 388.1
[torch.bfloat16, 8, 128] | 396.5 | 255.8 | 431.1
[torch.bfloat16, 16, 64] | 548.1 | 196.6 | 453.4
[torch.bfloat16, 16, 128] | 771.7 | 384.4 | 540.1
[torch.bfloat16, 32, 64] | 1071.6 | 368.4 | 355.8
[torch.bfloat16, 32, 128] | 1519.3 | 744.2 | 497.0
Times are in microseconds (us).
The Triton version is faster than the naive implementation only starting from 16 heads with dimension >= 64, and it's only faster than the other implementation (https://github.com/ROCm/flash-attention) starting from 32 heads with dimension >= 64 (the speedup in the (32, 128) case is very impressive, though).
Could you tell me if these results correspond to what you have in mind regarding these implementations?
@netw0rkf10w
I'll assume [torch.float16, 8, 64] means 8 heads with head_dim = 64.
Here are my observations from your benchmark, without knowing the GPU (it seems this is on MI200 GPUs), block sizes, or other tuning parameters you are using.
The number of workgroups is heads * batch_size * seq_len / BLOCK_M. If you have BLOCK_M = 256, then with 8 heads there are 128 workgroups. On MI200 there are about 110 CUs, so each CU gets only 1 workgroup, which makes it very hard for the hardware to find parallelism. With 16 heads there are 256 workgroups; each CU then has 2 to 3 workgroups, which is usually the minimum requirement to fully utilize the CU.
Hi @zhanglx13. Sorry for the late reply; my server has been broken over the last few days, so I've been waiting for it to be fixed before giving more feedback.
First of all, your last commit (https://github.com/ROCm/triton/commit/35edd6a650e3f6a56e3c2db9a54fdc2f8a6505e1) leads to numerically incorrect results. I had to revert to before that commit.
Second, below are some benchmarking results on an MI250x:
[------------------------------------------ FlashAttention ------------------------------------------]
| sdp_attention | torch sdpa | flash | triton-rocm
1 threads: -------------------------------------------------------------------------------------------
[torch.float16, 12, 64, 256, 64] | 872.7 | 741.4 | 234.8 | 510.6
[torch.float16, 12, 64, 256, 128] | 1645.0 | 1463.1 | 466.6 | 509.3
[torch.float16, 16, 64, 256, 64] | 1105.7 | 985.4 | 314.6 | 605.1
[torch.float16, 16, 64, 256, 128] | 2192.0 | 1948.7 | 628.1 | 683.3
[torch.float16, 16, 64, 576, 64] | 6016.5 | 5045.2 | 1864.9 | 1771.6
[torch.float16, 16, 64, 576, 128] | 12003.2 | 10068.3 | 3715.1 | 3533.0
[torch.float16, 16, 128, 256, 64] | 2246.9 | 2251.9 | 708.6 | 876.5
[torch.float16, 16, 128, 256, 128] | 4476.0 | 4481.0 | 1390.7 | 1173.0
[torch.float16, 16, 128, 784, 64] | 18295.2 | 16747.2 | 6723.6 | 5801.8
[torch.float16, 16, 128, 784, 128] | 36606.3 | 33475.4 | 13345.4 | 11521.5
[torch.bfloat16, 12, 64, 256, 64] | 1022.0 | 883.2 | 224.9 | 754.5
[torch.bfloat16, 12, 64, 256, 128] | 1934.5 | 1743.9 | 438.5 | 1006.2
[torch.bfloat16, 16, 64, 256, 64] | 1296.5 | 1170.6 | 299.1 | 956.1
[torch.bfloat16, 16, 64, 256, 128] | 2574.3 | 2320.5 | 580.9 | 1357.3
[torch.bfloat16, 16, 64, 576, 64] | 6028.7 | 4973.1 | 1771.1 | 2963.7
[torch.bfloat16, 16, 64, 576, 128] | 12071.5 | 10050.5 | 3492.6 | 5946.5
[torch.bfloat16, 16, 128, 256, 64] | 2197.3 | 2200.5 | 690.4 | 1244.4
[torch.bfloat16, 16, 128, 256, 128] | 4382.2 | 4385.6 | 1345.2 | 1910.6
[torch.bfloat16, 16, 128, 784, 64] | 16343.5 | 14736.8 | 6560.7 | 10183.0
[torch.bfloat16, 16, 128, 784, 128] | 32673.9 | 29445.0 | 13011.1 | 20128.2
The first column corresponds to the variables [dtype, num_heads, head_dim, seq_len, batch_size]. The above values of num_heads, head_dim, seq_len correspond roughly to the configurations of ViT base, large, and huge (the head dimension of ViT-huge is 80, which is not supported by flash-attn, so I used 128 instead).
PyTorch was not compiled with flash attention or memory-efficient attention (UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:264.)), so the torch sdpa column is there only for reference and does not reflect the true performance of torch.nn.functional.scaled_dot_product_attention once the compilation issue has been fixed (by your colleague at AMD).
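As an aside, one way to see which SDPA backends a given PyTorch 2.x build actually enables (standard PyTorch API; the output is build-dependent):

import torch

# Each flag is False when the corresponding kernel was compiled out, which is
# what the UserWarning above indicates for this ROCm build.
print("flash:        ", torch.backends.cuda.flash_sdp_enabled())
print("mem_efficient:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math:         ", torch.backends.cuda.math_sdp_enabled())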
From the results, the Triton version slightly outperforms the ROCm version (flash) on long sequence lengths in FP16, while the ROCm version completely outperforms the Triton version for bfloat16!
Moreover, for your information, the above results are way behind those on an NVIDIA A100:
[--------------------------------------- FlashAttention ---------------------------------------]
| sdp_attention | torch sdpa | flash | triton
1 threads: -------------------------------------------------------------------------------------
[torch.float16, 12, 64, 256, 64] | 518.2 | 98.1 | 119.9 | 264.7
[torch.float16, 12, 64, 256, 128] | 940.1 | 185.7 | 178.1 | 202.4
[torch.float16, 16, 64, 256, 64] | 639.5 | 127.9 | 121.5 | 289.8
[torch.float16, 16, 64, 256, 128] | 1240.0 | 244.5 | 233.3 | 266.9
[torch.float16, 16, 64, 576, 64] | 3374.3 | 695.5 | 641.5 | 744.5
[torch.float16, 16, 64, 576, 128] | 6707.5 | 1380.6 | 1272.6 | 1229.3
[torch.float16, 16, 128, 256, 64] | 735.9 | 242.1 | 230.7 | 416.9
[torch.float16, 16, 128, 256, 128] | 1431.7 | 453.4 | 412.4 | 526.6
[torch.float16, 16, 128, 784, 64] | 6392.1 | 2041.1 | 1915.0 | 2397.9
[torch.float16, 16, 128, 784, 128] | 12817.3 | 4054.9 | 3803.6 | 4520.8
[torch.bfloat16, 12, 64, 256, 64] | 493.7 | 97.7 | 92.7 |
[torch.bfloat16, 12, 64, 256, 128] | 950.7 | 183.3 | 175.2 |
[torch.bfloat16, 16, 64, 256, 64] | 645.7 | 125.7 | 119.4 |
[torch.bfloat16, 16, 64, 256, 128] | 1253.9 | 240.0 | 229.1 |
[torch.bfloat16, 16, 64, 576, 64] | 3143.5 | 686.3 | 632.4 |
[torch.bfloat16, 16, 64, 576, 128] | 6245.4 | 1360.7 | 1247.5 |
[torch.bfloat16, 16, 128, 256, 64] | 739.7 | 228.6 | 207.6 |
[torch.bfloat16, 16, 128, 256, 128] | 1435.2 | 442.5 | 402.7 |
[torch.bfloat16, 16, 128, 784, 64] | 6451.8 | 2009.5 | 1875.7 |
[torch.bfloat16, 16, 128, 784, 128] | 12877.8 | 3987.9 | 3732.9 |
@netw0rkf10w I tried to reproduce the perf numbers with an MI250X and here is what I got (times in us):

case | [dtype, H, D, L, B] | sdp_attention | torch sdpa | flash | triton | triton tuned
---|---|---|---|---|---|---
1 | [torch.float16, 12, 64, 256, 64] | 872.7 | 741.4 | 234.8 | 510.6 | 176.4
1 | [torch.float16, 12, 64, 256, 128] | 1645 | 1463.1 | 466.6 | 509.3 | 322.4
1 | [torch.float16, 16, 64, 256, 64] | 1105.7 | 985.4 | 314.6 | 605.1 | 225.4
1 | [torch.float16, 16, 64, 256, 128] | 2192 | 1948.7 | 628.1 | 683.3 | 420.0
2 | [torch.float16, 16, 64, 576, 64] | 6016.5 | 5045.2 | 1864.9 | 1771.6 | 1504.8
2 | [torch.float16, 16, 64, 576, 128] | 12003.2 | 10068.3 | 3715.1 | 3533 | 2998.4
3 | [torch.float16, 16, 128, 256, 64] | 2246.9 | 2251.9 | 708.6 | 876.5 | 557.2
3 | [torch.float16, 16, 128, 256, 128] | 4476 | 4481 | 1390.7 | 1173 | 1070.4
4 | [torch.float16, 16, 128, 784, 64] | 18295.2 | 16747.2 | 6723.6 | 5801.8 |
4 | [torch.float16, 16, 128, 784, 128] | 36606.3 | 33475.4 | 13345.4 | 11521.5 |
As you can see, some tuning is needed to bring up the perf. For the tuned column, also set:
export OPTIMIZE_EPILOGUE=1
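For reference, tuning a Triton kernel of this kind is usually done with @triton.autotune; the sketch below shows only the mechanism, since the actual block sizes and warp counts behind the tuned column are not given in the thread:

import triton

# Illustrative tuning space; the real values used for the tuned column are
# not shown here, so these block sizes and warp counts are guesses.
configs = [
    triton.Config({'BLOCK_M': bm, 'BLOCK_N': bn}, num_warps=w, num_stages=1)
    for bm in (64, 128, 256)
    for bn in (32, 64, 128)
    for w in (4, 8)
]

# The decorator would then wrap the kernel, e.g.:
# @triton.autotune(configs=configs, key=['seqlen_q', 'seqlen_k'])
# @triton.jit
# def attn_fwd(...):
#     ...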
For bfloat16, there is a known issue with conversion between f32 and bf16 in the kv-loop, hence the perf for bf16 is worse. We are working on this issue.
For the numerical errors with my last commit (https://github.com/ROCm/triton/commit/35edd6a650e3f6a56e3c2db9a54fdc2f8a6505e1), I tried some configs and it works. However, it's possible that some other configs fail, since that PR is primarily for GEMM kernels and was not extensively tested for FA. Could you provide the failing FA configs so that I can take a look?
Hi @zhanglx13!
Thanks a lot for your reply! The results with tuning look very promising!
I'm trying to get some more results before replying, but I encountered an issue that I haven't been able to resolve:
Traceback (most recent call last):
File "/home/user/code/tests/benchmark_flash_attention.py", line 271, in <module>
ex()
File "/home/user/code/tests/benchmark_flash_attention.py", line 163, in ex
triton_rocm_fp16 = flash_attention_triton_rocm(q.half(), k.half(), v.half())
File "/home/user/code/tests/benchmark_flash_attention.py", line 113, in flash_attention_triton_rocm
output = flash_attn_triton_rocm(q, k, v)
File "/home/user/code/torching/src/torching/flash_attention/triton_rocm/flash_attention.py", line 916, in attention
output, _ = _attention.apply(q, k, v, None, metadata)
File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/user/code/torching/src/torching/flash_attention/triton_rocm/flash_attention.py", line 805, in forward
attn_fwd[grid](
File "/home/user/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 581, in run
bin.c_wrapper(
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 745, in __getattribute__
self._init_handles()
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 736, in _init_handles
mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
SystemError: <built-in function load_binary> returned NULL without setting an exception
It has happened to me quite often, and each time I had to remove and re-install everything for it to work. This has become increasingly frustrating, to the point that I wanted to give up Triton altogether. Do you happen to know how to fix this issue, please?
I understand it's very frustrating, and I think it happened to me once as well.
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py
It looks like you are not using the Triton installed from this repo. Here is what you can try:
pip uninstall triton
cd python
pip install -e .
Note the cd python step: this could make a difference, since that is where setup.py lives.
If you still see errors that can be fixed by re-installing everything, let us know, and @jataylo from PyTorch can have more insight about the issue.
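A quick way to tell which installation Python is picking up (the site-packages path below is just from your traceback; yours may differ):

import triton

# If this prints .../site-packages/triton/... instead of a path inside your
# clone, pip is still resolving an old wheel rather than the editable install.
print(triton.__file__)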
File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py
It looks like you are not using the triton installed from this repo. Here is what you can try
Sorry that was because I tried both pip install -e . --user
and pip install . --user
, with the latter being tried after the former had failed to fix the issue. I only used this repo for the installation. By the way, pip uninstall triton
never worked for me:
$ pip uninstall triton
Found existing installation: triton 2.1.0
Can't uninstall 'triton'. No files were found to uninstall.
For uninstallation, I had to do:
rm -rf ~/.local/lib/python3.10/site-packages/triton*
rm ~/.local/lib/python3.10/site-packages/__editable__*triton*
rm -rf ~/code/triton/build ~/code/triton/python/triton.egg-info
But now uninstalling and re-installing Triton doesn't work anymore; the error remains. I don't know why :(
Hmm, OK, so the last re-installation didn't work because I forgot to git checkout d6f14d374e0a02ce113687954accbb3526b3324f, that is, to remove your last commit. Now it's working again!
What annoys me is that now it has numerical issues:
B = 5; H = 4; L = 10; D = 16
max diff r1_fp32 vs r1_fp16 = 0.0014356374740600586
max diff r1_fp32 vs r1_bf16 = 0.015220880508422852
max diff r1_fp32 vs r2_fp32 = 0.0
max diff r1_fp32 vs r2_fp16 = 0.0014356374740600586
max diff r1_fp32 vs r2_bf16 = 0.015220880508422852
max diff r1_fp32 vs r3_fp16 = 0.001074075698852539
max diff r1_fp32 vs r3_bf16 = 0.010561943054199219
max diff r1_fp32 vs triton_rocm_fp16 = 4.576810836791992
max diff r1_fp32 vs triton_rocm_bf16 = 4.579740524291992
This didn't happen before (it only happened with your most recent commit, as I mentioned earlier).
For larger sizes the results seem good:
For B = 64; H = 12; L = 256; D = 64:
max diff r1_fp32 vs r1_fp16 = 0.0020145773887634277
max diff r1_fp32 vs r1_bf16 = 0.01580953598022461
max diff r1_fp32 vs r2_fp32 = 8.344650268554688e-07
max diff r1_fp32 vs r2_fp16 = 0.0019515752792358398
max diff r1_fp32 vs r2_bf16 = 0.013421595096588135
max diff r1_fp32 vs r3_fp16 = 0.0009750127792358398
max diff r1_fp32 vs r3_bf16 = 0.00976407527923584
max diff r1_fp32 vs triton_rocm_fp16 = 0.0013276338577270508
max diff r1_fp32 vs triton_rocm_bf16 = 0.009515345096588135
Full script for reproducing:
import sys, os
import torch
import math
import torch.utils.benchmark as benchmark
from torch.nn.functional import scaled_dot_product_attention
import warnings

if not torch.cuda.is_available():
    warnings.warn("CUDA not available, Flash attention will not be used.")
    HAS_FLASH = False
    HAS_FLASH_TRITON_CUDA = False
    HAS_FLASH_TRITON_CUDA_DAO = False
    HAS_FLASH_TRITON_ROCM = False
else:
    if torch.cuda.get_device_capability('cuda') < (7, 5):
        warnings.warn("GPU not supported, Flash attention will not be used.")
        HAS_FLASH = False
        HAS_FLASH_TRITON_CUDA = False
        HAS_FLASH_TRITON_CUDA_DAO = False
        HAS_FLASH_TRITON_ROCM = False
    else:
        try:
            from flash_attn import flash_attn_func
            HAS_FLASH = True
        except ImportError:
            warnings.warn("Cannot import flash_attn")
            HAS_FLASH = False
        try:
            from torching.flash_attention import flash_attn_triton
            HAS_FLASH_TRITON_CUDA = True
        except ImportError:
            warnings.warn("Cannot import flash_attn_triton")
            HAS_FLASH_TRITON_CUDA = False
        try:
            from torching.flash_attention import flash_attn_triton_dao
            HAS_FLASH_TRITON_CUDA_DAO = True
        except ImportError:
            warnings.warn("Cannot import flash_attn_triton_dao")
            HAS_FLASH_TRITON_CUDA_DAO = False
        try:
            from torching.flash_attention import flash_attn_triton_rocm
            HAS_FLASH_TRITON_ROCM = True
        except ImportError:
            warnings.warn("Cannot import flash_attn_triton_rocm")
            HAS_FLASH_TRITON_ROCM = False


def sdp_attention(q, k, v, rel_pos=None):
    """
    q: (B, H, L, D)
    """
    # with torch.cuda.amp.autocast(enabled=True, dtype=q.dtype):
    scale = 1 / math.sqrt(q.size(-1))
    attn = (q @ k.transpose(-2, -1)) * scale  # (B, H, L, L)
    if rel_pos is not None:
        attn = attn + rel_pos
    attn = attn.softmax(dim=-1)  # (B, H, L, L)
    # attn = self.attn_drop(attn)
    result = attn @ v
    # (B, H, L, L) x (B, H, L, D) --> (B, H, L, D) --> (B, L, H, D)
    return result


def flash_attention(q, k, v):
    """
    shape: (B, H, L, D)
    """
    # flash attention expects (B, L, H, d)
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    result = flash_attn_func(
        q, k, v,
        dropout_p=0,
        softmax_scale=None,
        causal=False,
        return_attn_probs=False,
    )
    return result.transpose(1, 2)


def flash_attention_triton(q, k, v):
    """
    shape: (B, H, seqlen_q, D_HEAD)
    """
    output = flash_attn_triton(q, k, v)
    return output


def flash_attention_triton_dao(q, k, v):
    """
    shape: (B, H, seqlen_q, D_HEAD)
    """
    output = flash_attn_triton_dao(q, k, v)
    return output


def flash_attention_triton_rocm(q, k, v):
    """
    shape: (B, H, seqlen_q, D_HEAD)
    """
    output = flash_attn_triton_rocm(q, k, v)
    return output


def forward_backward(q, k, v, method=''):
    """
    forward + backward
    q: (B, H, L, D)
    """
    # forward
    output = sdp_attention(q, k, v)
    # backward
    grad = torch.rand(output.size()).cuda()
    output.backward(grad)
    return output


def ex():
    # (B, H, L, D)
    B = 5; H = 4; L = 10; D = 16
    # B = 5; H = 64; L = 128; D = 128
    # B = 64; H = 12; L = 256; D = 64
    q = torch.randn((B, H, L, D)).to('cuda')
    k = torch.randn((B, H, L, D)).to('cuda')
    v = torch.randn((B, H, L, D)).to('cuda')
    # bias = torch.randn((1, H, L, L))

    r1_fp32 = scaled_dot_product_attention(q, k, v)
    r1_fp16 = scaled_dot_product_attention(q.half(), k.half(), v.half())
    r1_bf16 = scaled_dot_product_attention(q.bfloat16(), k.bfloat16(), v.bfloat16())
    print(f"max diff r1_fp32 vs r1_fp16 = {(r1_fp32 - r1_fp16).abs().max()}")
    print(f"max diff r1_fp32 vs r1_bf16 = {(r1_fp32 - r1_bf16).abs().max()}")

    r2_fp32 = sdp_attention(q, k, v)
    r2_fp16 = sdp_attention(q.half(), k.half(), v.half())
    r2_bf16 = sdp_attention(q.bfloat16(), k.bfloat16(), v.bfloat16())
    print(f"max diff r1_fp32 vs r2_fp32 = {(r1_fp32 - r2_fp32).abs().max()}")
    print(f"max diff r1_fp32 vs r2_fp16 = {(r1_fp32 - r2_fp16).abs().max()}")
    print(f"max diff r1_fp32 vs r2_bf16 = {(r1_fp32 - r2_bf16).abs().max()}")

    if HAS_FLASH:
        # r3 = flash_attention(q, k, v)
        r3_fp16 = flash_attention(q.half(), k.half(), v.half())
        r3_bf16 = flash_attention(q.bfloat16(), k.bfloat16(), v.bfloat16())
        print(f"max diff r1_fp32 vs r3_fp16 = {(r1_fp32 - r3_fp16).abs().max()}")
        print(f"max diff r1_fp32 vs r3_bf16 = {(r1_fp32 - r3_bf16).abs().max()}")

    # if HAS_FLASH_TRITON_CUDA:
    #     triton_cuda_fp16 = flash_attention_triton(q.half(), k.half(), v.half())
    #     print(f"max diff r1_fp32 vs triton_cuda_fp16 = {(r1_fp32 - triton_cuda_fp16).abs().max()}")
    #     # triton_cuda_bf16 = flash_attention_triton(q.bfloat16(), k.bfloat16(), v.bfloat16())
    #     # print(f"max diff r1_fp32 vs triton_cuda_bf16 = {(r1_fp32 - triton_cuda_bf16).abs().max()}")

    # if HAS_FLASH_TRITON_CUDA_DAO:
    #     triton_cuda_dao_fp16 = flash_attention_triton_dao(q.half(), k.half(), v.half())
    #     print(f"max diff r1_fp32 vs triton_cuda_dao_fp16 = {(r1_fp32 - triton_cuda_dao_fp16).abs().max()}")
    #     triton_cuda_dao_bf16 = flash_attention_triton_dao(q.bfloat16(), k.bfloat16(), v.bfloat16())
    #     print(f"max diff r1_fp32 vs triton_cuda_dao_bf16 = {(r1_fp32 - triton_cuda_dao_bf16).abs().max()}")

    if HAS_FLASH_TRITON_ROCM:
        triton_rocm_fp16 = flash_attention_triton_rocm(q.half(), k.half(), v.half())
        triton_rocm_bf16 = flash_attention_triton_rocm(q.bfloat16(), k.bfloat16(), v.bfloat16())
        print(f"max diff r1_fp32 vs triton_rocm_fp16 = {(r1_fp32 - triton_rocm_fp16).abs().max()}")
        print(f"max diff r1_fp32 vs triton_rocm_bf16 = {(r1_fp32 - triton_rocm_bf16).abs().max()}")


def main():
    # ViT base: embed_dim=768, num_heads=12,
    #   patch=16, size=224: H=12, D=64, L=256
    # ViT large: embed_dim=1024, num_heads=16:
    #   patch=14, size=224: H=16, D=64, L=256
    #   patch=16, size=384: H=16, D=64, L=576
    # ViT huge: embed_dim=1280, num_heads=16:
    #   patch=16, size=224: H=16, D=80, L=256
    #   patch=16, size=448: H=16, D=80, L=784
    batch_sizes = [64, 128]
    # Compare takes a list of measurements which we'll save in results.
    results = []
    dtypes = [torch.float16, torch.bfloat16]
    # heads = [12, 16]
    # dims = [64, 80, 128]
    # seq_lengths = [256, 576, 784]
    # We replace 80 with 128 because the triton version does not support 80
    HDL = [(12, 64, 256), (16, 64, 256), (16, 64, 576), (16, 128, 256), (16, 128, 784)]
    # HAS_FLASH_TRITON = False
    for dtype in dtypes:
        for H, D, L in HDL:
            for B in batch_sizes:
                q = torch.randn((B, H, L, D)).to('cuda').to(dtype)
                k = torch.randn((B, H, L, D)).to('cuda').to(dtype)
                v = torch.randn((B, H, L, D)).to('cuda').to(dtype)
                # label and sub_label are the rows
                # description is the column
                label = 'FlashAttention'
                sub_label = f'[{dtype}, {H}, {D}, {L}, {B}]'
                results.append(benchmark.Timer(
                    stmt='sdp_attention(q, k, v)',
                    setup='from __main__ import sdp_attention',
                    globals={'q': q, 'k': k, 'v': v},
                    # num_threads=num_threads,
                    label=label,
                    sub_label=sub_label,
                    description='math',
                ).blocked_autorange(min_run_time=1))
                results.append(benchmark.Timer(
                    stmt='scaled_dot_product_attention(q, k, v)',
                    setup='from torch.nn.functional import scaled_dot_product_attention',
                    globals={'q': q, 'k': k, 'v': v},
                    # num_threads=num_threads,
                    label=label,
                    sub_label=sub_label,
                    description='torch sdpa',
                ).blocked_autorange(min_run_time=1))
                if dtype != torch.float32:
                    if HAS_FLASH:
                        results.append(benchmark.Timer(
                            stmt='flash_attention(q, k, v)',
                            setup='from __main__ import flash_attention',
                            globals={'q': q, 'k': k, 'v': v},
                            # num_threads=num_threads,
                            label=label,
                            sub_label=sub_label,
                            description='flash',
                        ).blocked_autorange(min_run_time=1))
                    if HAS_FLASH_TRITON_CUDA and dtype != torch.bfloat16:  # There's a bug with bfloat16
                        results.append(benchmark.Timer(
                            stmt='flash_attention_triton(q, k, v)',
                            setup='from __main__ import flash_attention_triton',
                            globals={'q': q, 'k': k, 'v': v},
                            # num_threads=num_threads,
                            label=label,
                            sub_label=sub_label,
                            description='triton-cuda',
                        ).blocked_autorange(min_run_time=1))
                    if HAS_FLASH_TRITON_CUDA_DAO:
                        results.append(benchmark.Timer(
                            stmt='flash_attention_triton_dao(q, k, v)',
                            setup='from __main__ import flash_attention_triton_dao',
                            globals={'q': q, 'k': k, 'v': v},
                            # num_threads=num_threads,
                            label=label,
                            sub_label=sub_label,
                            description='triton-cuda-dao',
                        ).blocked_autorange(min_run_time=1))
                    if HAS_FLASH_TRITON_ROCM:
                        results.append(benchmark.Timer(
                            stmt='flash_attention_triton_rocm(q, k, v)',
                            setup='from __main__ import flash_attention_triton_rocm',
                            globals={'q': q, 'k': k, 'v': v},
                            # num_threads=num_threads,
                            label=label,
                            sub_label=sub_label,
                            description='triton-rocm',
                        ).blocked_autorange(min_run_time=1))
    compare = benchmark.Compare(results)
    compare.print()


if __name__ == '__main__':
    # main()
    ex()
I've tried training and obtained the following error:
Traceback (most recent call last):
File "/home/user/code/ViT/main_finetune.py", line 212, in <module>
main(args)
File "/home/user/code/ViT/main_finetune.py", line 205, in main
training_loop(args, model, device, data_loader_train, data_loader_val,
File "/home/user/code/ViT/common.py", line 423, in training_loop
train_stats = train_fn(args, model, device, data_loader_train,
File "/home/user/code/ViT/engine_finetune.py", line 90, in train_epoch
loss_scaler(loss, optimizer, clip_grad=max_norm,
File "/home/user/code/ViT/utils/optim.py", line 48, in __call__
self._scaler.scale(loss).backward(create_graph=create_graph)
File "/home/user/.local/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
torch.autograd.backward(
File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
return user_fn(self, *args)
File "/home/user/code/torching/src/torching/flash_attention/triton_rocm/flash_attention.py", line 846, in backward
assert do.is_contiguous()
AssertionError
What I did was to create a wrapper function like this:

def attention(q, k, v, causal=False, sm_scale=None, sequence_parallel=False):
    if not sm_scale:
        sm_scale = q.shape[-1] ** -0.5
    metadata = MetaData(sm_scale=sm_scale)
    metadata.max_seqlens_q = q.shape[-2]
    metadata.max_seqlens_k = k.shape[-2]
    metadata.causal = causal
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    output, _ = _attention.apply(q, k, v, None, metadata)
    return output.contiguous()

I tried adding .contiguous() everywhere, as above, but this didn't help...
Update: Adding the following helped:
if not do.is_contiguous():
    do = do.contiguous()
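For context, the gradient handed to backward() inherits the memory layout of whatever downstream op produced it, so it is not guaranteed to be contiguous; .contiguous() materializes a dense copy only when needed. A tiny illustration:

import torch

g = torch.randn(2, 3).t()   # a transposed view: same data, swapped strides
print(g.is_contiguous())    # False, which is what trips the kernel's assert
g = g.contiguous()          # dense copy in row-major order
print(g.is_contiguous())    # True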
Lol, launching a multi-node training, I again obtained SystemError: <built-in function load_binary> returned NULL without setting an exception.
What annoys me is that now it has numerical issues:
I forgot to mention that the original triton implementation and @tridao's triton implementation also suffer from numerical incorrectness.
Some more updates. It seems that the Triton implementation really shines in the backward pass. Below are the results for forward+backward:
| math | torch | flash | triton-rocm
[torch.float16, 12, 64, 256, 64] | 3.0 | 2.8 | 1.3 | 1.1
[torch.float16, 12, 64, 256, 128] | 5.9 | 5.5 | 2.6 | 1.3
[torch.float16, 16, 64, 256, 64] | 4.0 | 3.7 | 1.8 | 1.2
[torch.float16, 16, 64, 256, 128] | 7.9 | 7.4 | 3.5 | 1.7
[torch.float16, 16, 64, 576, 64] | 16.7 | 14.8 | 8.2 | 3.6
[torch.float16, 16, 64, 576, 128] | 33.4 | 29.4 | 16.2 | 7.0
[torch.float16, 16, 128, 256, 64] | 7.5 | 7.5 | 3.5 | 2.5
[torch.float16, 16, 128, 256, 128] | 14.9 | 14.9 | 6.9 | 3.7
[torch.float16, 16, 128, 784, 64] | 54.6 | 51.5 | 26.7 | 14.6
[torch.float16, 16, 128, 784, 128] | 109.1 | 102.9 | 52.9 | 27.2
[torch.bfloat16, 12, 64, 256, 64] | 3.4 | 3.2 | 1.6 | 1.3
[torch.bfloat16, 12, 64, 256, 128] | 6.7 | 6.3 | 3.1 | 1.9
[torch.bfloat16, 16, 64, 256, 64] | 4.5 | 4.2 | 2.1 | 1.6
[torch.bfloat16, 16, 64, 256, 128] | 8.9 | 8.4 | 4.1 | 2.5
[torch.bfloat16, 16, 64, 576, 64] | 16.6 | 14.6 | 10.2 | 5.0
[torch.bfloat16, 16, 64, 576, 128] | 33.3 | 29.1 | 20.2 | 9.8
[torch.bfloat16, 16, 128, 256, 64] | 7.5 | 7.5 | 4.3 | 3.0
[torch.bfloat16, 16, 128, 256, 128] | 15.0 | 15.0 | 8.5 | 4.7
[torch.bfloat16, 16, 128, 784, 64] | 57.5 | 54.2 | 34.6 | 17.9
[torch.bfloat16, 16, 128, 784, 128] | 114.7 | 108.2 | 68.7 | 33.6
Times are in milliseconds (ms).
Triton is often twice as fast as flash-attn!!!
Unfortunately, I was not able to use it for training because of the above SystemError.
Unfortunately I was not able to use it for training because of the above SystemError.
It's finally working after I deleted ~/.triton! Apparently it was some cache issue.
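In case it helps others, a small hypothetical helper for the same situation; as far as I can tell the kernel cache lives under ~/.triton by default and can be redirected with the TRITON_CACHE_DIR environment variable:

import os
import shutil

# Remove Triton's on-disk kernel cache so stale binaries cannot be loaded.
cache_dir = os.environ.get("TRITON_CACHE_DIR",
                           os.path.expanduser("~/.triton"))
if os.path.isdir(cache_dir):
    shutil.rmtree(cache_dir)
    print(f"removed {cache_dir}")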
Unfortunately though, it didn't learn anything (the training curve stayed flat in the attached figure). I officially gave up!
Now we are migrating to upstream. @netw0rkf10w Feel free to open new tickets regarding issues with the AMD backend at https://github.com/openai/triton/issues
Hello,
I have installed PyTorch 2.3.0.dev20240211+rocm5.7, which includes Triton 3.0.0+dafe145982. Could you please tell me if we can use the flash attention implementation of this repo with that version of Triton? If we have to install the version from this repo, will that break PyTorch's installation? Thank you very much in advance for your answer!