ROCm / triton

Development repository for the Triton language and compiler
MIT License

Low performance of fmha with head_dim=128 #365

Closed. minminsun closed this issue 1 year ago.

minminsun commented 1 year ago

Running tutorials/06-fused-attention.py on MI210, the performance with head_dim=64 looks good:

fused-attention-batch4-head48-d64-fwd:
     N_CTX      Triton
0   1024.0   92.771498
1   2048.0  100.215724
2   4096.0  102.864525
3   8192.0  104.837738
4  16384.0  105.736522

But the performance drops dramatically when head_dim is changed to 128:

fused-attention-batch4-head48-d128-fwd:
     N_CTX     Triton
0   1024.0  12.889821
1   2048.0  13.136331
2   4096.0  13.232509
3   8192.0  13.355740
4  16384.0  13.383783

Performance with head_dim=128 matters a lot because it is the head dimension used by the most popular LLMs.

Do you have any idea why this happens and how it can be improved? Thanks!
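For anyone reproducing these numbers, here is a minimal sketch of how the tutorial's benchmark can be pointed at head_dim=128. The import path `fused_attention_tutorial` and the exact signature of `attention` are assumptions (some versions of the tutorial take an extra `causal` flag); the FLOP formula follows the tutorial's convention of counting the two GEMMs.

```python
import torch
import triton

# Hypothetical import: `attention` is the fused-attention op defined in
# tutorials/06-fused-attention.py; the real import path and signature
# (possibly with a `causal` argument) depend on the Triton checkout.
from fused_attention_tutorial import attention

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 128   # head_dim=128 instead of 64
q = torch.randn((BATCH, N_HEADS, N_CTX, D_HEAD), dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
sm_scale = D_HEAD ** -0.5

ms = triton.testing.do_bench(lambda: attention(q, k, v, sm_scale))
flops = 2 * 2.0 * BATCH * N_HEADS * N_CTX * N_CTX * D_HEAD   # QK^T and PV matmuls
print(f"{flops / (ms * 1e-3) * 1e-12:.1f} TFLOP/s at N_CTX={N_CTX}")
```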

minminsun commented 1 year ago

Update: after upgrading to the latest commit 821e75a2b025c62e4fab0578e32a12b5ca5fc9e9, the performance gets much better.

fused-attention-batch4-head48-d128-fwd-causal=False:
     N_CTX     Triton
0   1024.0  69.581110
1   2048.0  73.569046
2   4096.0  75.318463
3   8192.0  76.932992
4  16384.0  76.648059

The performance with head_dim=128 is still only around 70% of head_dim=64, so further improvement may be needed.

zhanglx13 commented 1 year ago

@minminsun For D=128, can you try this branch https://github.com/ROCmSoftwarePlatform/triton/tree/reduce_lds_usage

For D=128 we need a larger tile size and some optimizations in the epilogue to reduce LDS usage. It should be merged into the triton-mlir branch soon.
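For readers unfamiliar with the constraint being discussed, here is a rough, illustrative estimate of why D=128 presses on LDS capacity. The tile sizes and the Q/K/V staging below are assumptions for illustration only, not the compiler's actual layout.

```python
# Illustrative only: estimate the LDS needed to stage fp16 Q/K/V tiles for
# one workgroup. CDNA2 (MI210) exposes 64 KB of LDS per workgroup, so doubling
# D_HEAD from 64 to 128 roughly doubles this footprint and forces either
# smaller tiles or epilogue optimizations that reuse LDS.
BLOCK_M, BLOCK_N, D_HEAD = 128, 64, 128   # hypothetical tile sizes
BYTES_FP16 = 2

q_tile = BLOCK_M * D_HEAD * BYTES_FP16
k_tile = BLOCK_N * D_HEAD * BYTES_FP16
v_tile = BLOCK_N * D_HEAD * BYTES_FP16
print(f"~{(q_tile + k_tile + v_tile) / 1024:.0f} KB of LDS vs. a 64 KB budget")
```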

minminsun commented 1 year ago

@zhanglx13 Thanks for your response! But the performance gets lower after switching to the reduce_lds_usage branch (commit a6db42d2a2ccecaee279efdfd98d1cb530f75049) on MI210:

fused-attention-batch4-head48-d128-fwd-causal=False:
     N_CTX     Triton
0   1024.0  44.790182
1   2048.0  47.149518
2   4096.0  48.370925
3   8192.0  49.587914
4  16384.0  49.441485

zhanglx13 commented 1 year ago

@minminsun This is unexpected. Let me ping someone to try it.

@alefimov-amd Can you try this on a MI200 or MI210 machine?

zhanglx13 commented 1 year ago

@minminsun Can you please double-check that you rebuilt Triton after switching to the reduce_lds_usage branch?
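One quick, generic way to rule out a stale install is to check which Triton build the benchmark process is actually importing (this is a general Python check, not something specific to this branch):

```python
# Confirm which Triton build is active in the Python environment.
import triton

print(triton.__version__)   # version string of the active installation
print(triton.__file__)      # path should point into the rebuilt checkout,
                            # not a previously installed site-packages wheel
```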

binarman commented 1 year ago

@zhanglx13 @minminsun I've tried the reduce_lds_usage branch on MI210 and got this:

fused-attention-batch4-head48-d128-fwd-causal=False:
     N_CTX      Triton    Flash-1
0   1024.0   85.164008  44.592183
1   2048.0   98.000198  50.076568
2   4096.0   94.663083  59.136411
3   8192.0  104.370936  60.812629
4  16384.0  105.686353  60.751002
Full output:

```
fused-attention-batch4-head48-d64-fwd-causal=False:
     N_CTX      Triton    Flash-1
0   1024.0   93.825532  41.027584
1   2048.0   97.536282  52.261408
2   4096.0   99.883993  56.738337
3   8192.0  101.486341  58.890442
4  16384.0  102.066147  59.661089
fused-attention-batch4-head48-d128-fwd-causal=False:
     N_CTX      Triton    Flash-1
0   1024.0   85.164008  44.592183
1   2048.0   98.000198  50.076568
2   4096.0   94.663083  59.136411
3   8192.0  104.370936  60.812629
4  16384.0  105.686353  60.751002
fused-attention-batch4-head48-d64-fwd-causal=True:
     N_CTX     Triton    Flash-1
0   1024.0  47.027273  21.017882
1   2048.0  65.191805  33.213573
2   4096.0  78.731786  44.097055
3   8192.0  86.568173  50.514463
4  16384.0  91.252226  53.661108
fused-attention-batch4-head48-d128-fwd-causal=True:
     N_CTX     Triton    Flash-1
0   1024.0  43.593958  22.609236
1   2048.0  52.018652  34.671581
2   4096.0  72.665392  45.141401
3   8192.0  89.742730  51.127007
4  16384.0  99.105659  53.838361
fused-attention-batch4-head48-d64-bwd-causal=True:
     N_CTX     Triton    Flash-1
0   1024.0  16.447319  11.595156
1   2048.0  21.611797  17.777656
2   4096.0  25.786854  23.892899
3   8192.0  28.394455  28.854413
4  16384.0  29.082691  31.457410
```

binarman commented 1 year ago

On top of the triton-mlir branch, I also got slightly better results:

fused-attention-batch4-head48-d128-fwd-causal=False:
     N_CTX     Triton    Flash-1
0   1024.0  73.678773  44.532794
1   2048.0  77.849647  54.586812
2   4096.0  80.303655  59.137811
3   8192.0  82.175344  60.307481
4  16384.0  82.501823  60.928475

@minminsun Could you share what changes you made to the tutorial script?

binarman commented 1 year ago

P.S. Could you also check different ROCm versions? They can affect performance a little. I was using 5.4.0.
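If it helps, one way to confirm which ROCm/HIP stack the benchmark is actually running on from Python (assuming a ROCm build of PyTorch, as the tutorial uses torch tensors):

```python
# torch.version.hip is populated only on ROCm builds of PyTorch.
import torch

print(torch.version.hip)              # e.g. a 5.4.x string on a ROCm 5.4 stack
print(torch.cuda.get_device_name(0))  # should report the MI210
```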

minminsun commented 1 year ago

Thanks @binarman!

> Could you share what changes you made to the tutorial script?

I made no changes to the tutorial script.

> P.S. Could you also check different ROCm versions? They can affect performance a little.

The version I use is ROCm 5.2.3. This might be the reason for the slowdown.

zhanglx13 commented 1 year ago

105 TFLOPS for D=128 on MI210 seems reasonable. Thanks @binarman!
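For context on "seems reasonable", here is a rough back-of-the-envelope efficiency check. The peak figure is an approximate spec number and not taken from this thread:

```python
# Approximate sanity check: MI210's advertised FP16 matrix peak is roughly
# 181 TFLOP/s, so ~105 TFLOP/s for the D=128 forward pass lands around
# 55-60% of peak, a plausible level for a fused attention kernel.
MI210_PEAK_FP16_TFLOPS = 181.0     # approximate spec value (assumption)
measured = 105.686353              # N_CTX=16384 row from the table above
print(f"~{measured / MI210_PEAK_FP16_TFLOPS:.0%} of peak")
```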

jayfurmanek commented 1 year ago

Give it a try with a newer ROCm. Please re-open as needed.