Closed: minminsun closed this issue 1 year ago.
Update: After upgrading to the latest commit 821e75a2b025c62e4fab0578e32a12b5ca5fc9e9, the performance gets much better.
fused-attention-batch4-head48-d128-fwd-causal=False:
     N_CTX     Triton
0   1024.0  69.581110
1   2048.0  73.569046
2   4096.0  75.318463
3   8192.0  76.932992
4  16384.0  76.648059
The performance with head_dim=128 is still only around 70% of that with head_dim=64 (roughly 77 vs. 105 TFLOPS), so further improvement may still be needed.
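(For reference, the figures above are TFLOP/s as reported by the tutorial's benchmark. Below is a minimal sketch of how that number is usually derived, assuming the standard flops accounting in 06-fused-attention.py; the measured time is a made-up placeholder.)

# Sketch of the TFLOP/s accounting used by the fused-attention benchmark.
# Assumes the usual counting of two matmuls (Q @ K^T and P @ V) per forward
# pass; `ms` is a hypothetical measured kernel time, not a real measurement.
BATCH, H, N_CTX, D_HEAD = 4, 48, 16384, 128
ms = 5.0  # hypothetical milliseconds

flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 2 * flops_per_matmul
tflops = total_flops * 1e-12 / (ms * 1e-3)
print(f"{tflops:.2f} TFLOP/s")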
@minminsun For D=128, can you try this branch https://github.com/ROCmSoftwarePlatform/triton/tree/reduce_lds_usage
For D=128, we need a larger tile size and some optimizations in the epilogue to reduce LDS usage. It should be merged into the triton-mlir branch soon.
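(To make the LDS pressure concrete, here is a rough sketch of the trade-off; this is not the branch's actual code. The BLOCK_M/BLOCK_N names follow the upstream tutorial's attention forward, and the heuristic shown is only an assumption.)

# Illustrative only: how the K/V tile sizes drive LDS (shared memory) usage.
# MI200/MI210 provides 64 KB of LDS per workgroup, so the fp16 tiles staged
# for the two dots must fit alongside everything else the kernel keeps in LDS.
D_HEAD = 128
BLOCK_M = 128                            # rows of Q handled per program
BLOCK_N = 64 if D_HEAD <= 64 else 32     # assumed heuristic: shrink for D=128
bytes_per_elem = 2                       # fp16
k_tile_bytes = BLOCK_N * D_HEAD * bytes_per_elem
v_tile_bytes = BLOCK_N * D_HEAD * bytes_per_elem
print(f"K+V tiles: {(k_tile_bytes + v_tile_bytes) / 1024:.0f} KB of LDS")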
@zhanglx13 Thanks for your response! But the performance is lower after switching to the reduce_lds_usage branch (commit a6db42d2a2ccecaee279efdfd98d1cb530f75049) on MI210.
fused-attention-batch4-head48-d128-fwd-causal=False:
     N_CTX     Triton
0   1024.0  44.790182
1   2048.0  47.149518
2   4096.0  48.370925
3   8192.0  49.587914
4  16384.0  49.441485
@minminsun This is unexpected. Let me ping someone to try it.
@alefimov-amd Can you try this on a MI200 or MI210 machine?
@minminsun Can you please double-check that you rebuilt Triton after switching to the reduce_lds_usage branch?
@zhanglx13 @minminsun
I've tried the reduce_lds_usage branch on MI210 and got this:
fused-attention-batch4-head48-d128-fwd-causal=False:
N_CTX Triton Flash-1
0 1024.0 85.164008 44.592183
1 2048.0 98.000198 50.076568
2 4096.0 94.663083 59.136411
3 8192.0 104.370936 60.812629
4 16384.0 105.686353 60.751002
On top of the triton-mlir branch I also got slightly better results:
fused-attention-batch4-head48-d128-fwd-causal=False:
N_CTX Triton Flash-1
0 1024.0 73.678773 44.532794
1 2048.0 77.849647 54.586812
2 4096.0 80.303655 59.137811
3 8192.0 82.175344 60.307481
4 16384.0 82.501823 60.928475
@minminsun Could you share what changes you made to the tutorial script?
P.S. Could you also check different ROCm versions? They can affect performance a little. I was using 5.4.0.
Thanks @binarman!
Could you share what changes you made to the tutorial script?
I made no changes to the tutorial script.
P.S. Could you also check different ROCm versions? They can affect performance a little.
The version I use is ROCm 5.2.3. This might be the reason for the slow-down.
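(For anyone reproducing this, a quick way to confirm which ROCm/HIP toolchain the PyTorch build in use targets, assuming PyTorch was installed from a ROCm wheel:)

import torch
# On a ROCm build of PyTorch, torch.version.hip reports the HIP/ROCm version
# the wheel was built against; it is None on CUDA builds.
print(torch.version.hip)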
105 TFLOPS for D=128 on MI210 seems reasonable. Thanks @binarman
Give it a try with a newer ROCm. Please re-open as needed.
Running tutorials/06-fused-attention.py on MI210, the performance with head_dim=64 looks good:
fused-attention-batch4-head48-d64-fwd:
     N_CTX      Triton
0   1024.0   92.771498
1   2048.0  100.215724
2   4096.0  102.864525
3   8192.0  104.837738
4  16384.0  105.736522
But the performance drops dramatically when head_dim is changed to 128:
fused-attention-batch4-head48-d128-fwd:
     N_CTX     Triton
0   1024.0  12.889821
1   2048.0  13.136331
2   4096.0  13.232509
3   8192.0  13.355740
4  16384.0  13.383783
Performance with head_dim=128 is especially important because it is the head dimension used by the most popular LLMs.
Do you have any idea why this happens and how it can be improved? Thanks!
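(For reproduction: the d128 numbers above come from switching the benchmark's head dimension from 64 to 128. A hypothetical sketch of that config follows; the variable names are based on the upstream 06-fused-attention.py and the exact layout may differ in the triton-mlir branch.)

import triton

# Hypothetical sketch of the benchmark config in 06-fused-attention.py with
# D_HEAD set to 128 instead of 64; names and shapes are assumptions based on
# the upstream tutorial, not the exact code of the triton-mlir branch.
BATCH, N_HEADS, D_HEAD = 4, 48, 128
configs = [
    triton.testing.Benchmark(
        x_names=["N_CTX"],
        x_vals=[2**i for i in range(10, 15)],  # 1024 ... 16384
        line_arg="provider",
        line_vals=["triton"],
        line_names=["Triton"],
        ylabel="TFLOPS",
        plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-fwd",
        args={"BATCH": BATCH, "H": N_HEADS, "D_HEAD": D_HEAD},
    )
]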