Closed: ilia-cher closed this PR 1 week ago.
Some preliminary perf improvements (measured with a prototype that can be tuned further, so the numbers may change).

After:
fused-attention-fwd-d128-causal=False-float8_e4m3fnuz:

| BATCH | H | N_CTX | Triton |
|------:|-----:|--------:|-----------:|
| 16.0 | 16.0 | 1024.0 | 320.352257 |
| 8.0 | 16.0 | 2048.0 | 440.529571 |
| 4.0 | 16.0 | 4096.0 | 556.378882 |
| 2.0 | 16.0 | 8192.0 | 612.439750 |
| 1.0 | 16.0 | 16384.0 | 636.701905 |
| 4.0 | 48.0 | 1024.0 | 309.433342 |
| 4.0 | 48.0 | 2048.0 | 436.186349 |
| 4.0 | 48.0 | 4096.0 | 566.117595 |
| 4.0 | 48.0 | 8192.0 | 630.003489 |
| 4.0 | 48.0 | 16384.0 | 659.517658 |
Before:

fused-attention-fwd-d128-causal=False-float8_e4m3fnuz:

| BATCH | H | N_CTX | Triton |
|------:|-----:|--------:|-----------:|
| 16.0 | 16.0 | 1024.0 | 285.783958 |
| 8.0 | 16.0 | 2048.0 | 410.978326 |
| 4.0 | 16.0 | 4096.0 | 499.566102 |
| 2.0 | 16.0 | 8192.0 | 556.077288 |
| 1.0 | 16.0 | 16384.0 | 581.175466 |
| 4.0 | 48.0 | 1024.0 | 270.836998 |
| 4.0 | 48.0 | 2048.0 | 409.327009 |
| 4.0 | 48.0 | 4096.0 | 513.261115 |
| 4.0 | 48.0 | 8192.0 | 573.631221 |
| 4.0 | 48.0 | 16384.0 | 606.831879 |
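For reference, comparing rows pairwise this is roughly a 7-14% throughput improvement across the measured shapes (e.g. 320.35 / 285.78 ≈ 1.12 for the first row, 659.52 / 606.83 ≈ 1.09 for the last).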
Addressed comments, including adding mfma_16 support. cc @zhanglx13
@ilia-cher I'd like to discuss two future PRs regarding chained_matmul_fp8
- Can we also support mfma_16 instructions?
- Is it possible to have the dotOperand tile be a single 32x32 tile, with each thread holding 16 consecutive elements? That way we could use ds_read_b128 for the V tensor before the 2nd matmul.
mfma_16 is now supported in this PR. Re: kWidth=16 -- sounds good; as discussed, let's do that in a follow-up.
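For readers following along: with fp8 elements, 16 consecutive elements per thread is exactly 128 bits, which is why a single ds_read_b128 (a 128-bit LDS read) per thread becomes possible. A back-of-the-envelope check, purely illustrative:

```python
# Illustrative arithmetic only: why 16 consecutive fp8 elements per
# thread line up with one ds_read_b128 per thread.
elem_bits = 8            # float8_e4m3fnuz is an 8-bit type
elems_per_thread = 16    # consecutive elements held by each thread
assert elem_bits * elems_per_thread == 128  # one 128-bit LDS read
```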
Adding a case for the MFMA->Dot (FP8) layout conversion that avoids using shared memory, to speed up FP8 attention kernels.

Test: lit test; `ctest -j32`
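For context, here is a minimal sketch of the chained-matmul FP8 pattern this conversion targets; the kernel name, block shape, pointer math, and the `tl.float8e4b8` downcast are illustrative assumptions, not the PR's actual kernel. The first `tl.dot` produces an accumulator in MFMA layout, which must be converted to a dot-operand layout before it can feed the second `tl.dot`; performing that conversion without a round trip through shared memory is what this PR adds.

```python
# Hedged sketch of the chained-matmul FP8 attention pattern.
# All names and shapes are illustrative.
import triton
import triton.language as tl

@triton.jit
def chained_matmul_fp8_sketch(q_ptr, k_ptr, v_ptr, o_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    q = tl.load(q_ptr + idx)
    k = tl.load(k_ptr + idx)
    v = tl.load(v_ptr + idx)
    p = tl.dot(q, k)          # 1st matmul: accumulator lands in MFMA layout
    p = p.to(tl.float8e4b8)   # downcast to an fnuz-style fp8 type (AMD variant)
    acc = tl.dot(p, v)        # 2nd matmul: p must be in Dot-operand layout,
                              # i.e. this is where MFMA->Dot conversion happens
    tl.store(o_ptr + idx, acc)
```

With the shared-memory-free conversion path, the downcast result can feed the second `tl.dot` without being staged through LDS between the two matmuls.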
[x] I am not making a trivial change, such as fixing a typo in a comment.
[x] I have written a PR description following these rules.
[x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
Select one of the following.
[x] I have added tests (`/test` for `lit` tests, `/unittest` for C++ tests, `/python/test` for end-to-end tests).
[ ] This PR does not need a test because `FILL THIS IN`.

Select one of the following.
[ ] I have not added any `lit` tests.
[x] The `lit` tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)