triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

[AMD] Use warp shuffle for MFMA to Dot operand layout conversion (FP8) #5139

Closed: ilia-cher closed this 1 week ago

ilia-cher commented 2 weeks ago

Adding a case for MFMA->Dot (FP8) layout conversion that avoids using shared memory, to speed up FP8 attention kernels. Tested with a lit test and `ctest -j32`.
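
A minimal sketch of the chained-matmul pattern this conversion targets (not the actual attention kernel behind the numbers below; names, shapes, and block sizes are illustrative, and it assumes a ROCm Triton build with FP8 support): the FP32 accumulator of the first `tl.dot` is downcast to FP8 and fed straight into a second `tl.dot`, which requires an MFMA -> dot-operand layout conversion in between.

```python
# Illustrative sketch only: single-block chained FP8 matmuls.
import torch
import triton
import triton.language as tl


@triton.jit
def chained_fp8_dots(q_ptr, k_ptr, v_ptr, o_ptr,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                     BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Row-major loads of the FP8 tiles (one block per kernel for brevity).
    q = tl.load(q_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    k = tl.load(k_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    v = tl.load(v_ptr + offs_n[:, None] * BLOCK_K + offs_k[None, :])
    # 1st matmul: the FP32 accumulator is produced in the MFMA layout.
    p = tl.dot(q, k)
    # Downcast and chain into the 2nd matmul: converting p from the MFMA
    # layout to a dot-operand layout is the step this PR moves from shared
    # memory to warp shuffles.
    o = tl.dot(p.to(v.dtype), v)
    tl.store(o_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :], o)


if __name__ == "__main__":
    M = N = K = 128
    f8 = torch.float8_e4m3fnuz  # the FP8 flavor used in the benchmark below
    q = torch.randn(M, K, device="cuda", dtype=torch.float16).to(f8)
    k = torch.randn(K, N, device="cuda", dtype=torch.float16).to(f8)
    v = torch.randn(N, K, device="cuda", dtype=torch.float16).to(f8)
    o = torch.empty(M, K, device="cuda", dtype=torch.float32)
    chained_fp8_dots[(1,)](q, k, v, o, BLOCK_M=M, BLOCK_N=N, BLOCK_K=K)
```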

ilia-cher commented 2 weeks ago

Some preliminary perf improvements (measured with a prototype that can be tuned further, so the numbers may change). After:

fused-attention-fwd-d128-causal=False-float8_e4m3fnuz:
   BATCH     H    N_CTX      Triton
0   16.0  16.0   1024.0  320.352257
1    8.0  16.0   2048.0  440.529571
2    4.0  16.0   4096.0  556.378882
3    2.0  16.0   8192.0  612.439750
4    1.0  16.0  16384.0  636.701905
5    4.0  48.0   1024.0  309.433342
6    4.0  48.0   2048.0  436.186349
7    4.0  48.0   4096.0  566.117595
8    4.0  48.0   8192.0  630.003489
9    4.0  48.0  16384.0  659.517658

before:

fused-attention-fwd-d128-causal=False-float8_e4m3fnuz:
   BATCH     H    N_CTX      Triton
0   16.0  16.0   1024.0  285.783958
1    8.0  16.0   2048.0  410.978326
2    4.0  16.0   4096.0  499.566102
3    2.0  16.0   8192.0  556.077288
4    1.0  16.0  16384.0  581.175466
5    4.0  48.0   1024.0  270.836998
6    4.0  48.0   2048.0  409.327009
7    4.0  48.0   4096.0  513.261115
8    4.0  48.0   8192.0  573.631221
9    4.0  48.0  16384.0  606.831879
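
For a quick read on the gain, the per-row ratio of the two tables above can be computed directly from the reported numbers:

```python
# Per-row ratio of the "after" vs "before" numbers above (same row order).
after = [320.352257, 440.529571, 556.378882, 612.439750, 636.701905,
         309.433342, 436.186349, 566.117595, 630.003489, 659.517658]
before = [285.783958, 410.978326, 499.566102, 556.077288, 581.175466,
          270.836998, 409.327009, 513.261115, 573.631221, 606.831879]
print([f"{a / b:.3f}x" for a, b in zip(after, before)])
# roughly 1.07x to 1.14x across the ten configurations
```
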
zhanglx13 commented 1 week ago

@ilia-cher I'd like to discuss two future PRs regarding chained_matmul_fp8

  1. Can we also support mfma_16 instructions?
  2. Is it possible to have the dotOperand tile as one 32x32 tile, where each thread holds 16 consecutive elements? That way we could do a ds_read_b128 for the V tensor before the 2nd matmul.
ilia-cher commented 1 week ago

Addressed comments, including adding mfma_16 support. cc @zhanglx13

ilia-cher commented 1 week ago

> @ilia-cher I'd like to discuss two future PRs regarding chained_matmul_fp8
>
> 1. Can we also support mfma_16 instructions?
> 2. Is it possible to have the dotOperand tile as one 32x32 tile, where each thread holds 16 consecutive elements? That way we could do a ds_read_b128 for the V tensor before the 2nd matmul.

mfma_16 is now supported in this PR. Re: kWidth=16, sounds good; as discussed, let's do that in a follow-up.
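
For reference, the byte count behind the kWidth=16 / ds_read_b128 idea (illustrative arithmetic only, not code from this PR):

```python
# 16 consecutive fp8 elements per thread -> one 128-bit LDS read per thread.
elements_per_thread = 16       # the proposed kWidth
bits_per_element = 8           # float8_e4m3fnuz is 1 byte wide
assert elements_per_thread * bits_per_element == 128  # width of ds_read_b128
```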