I tried your code on MI250 with ROCm 5.6. test_op_fwd passes after I change
best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage)
to
best_config = _attn_fwd.get_best_config()
The bwd tests did not work.
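For reference, the same change as a snippet. `_attn_fwd` is the autotuned kernel object in that script; the zero-argument form is what worked on this build (the keyword-argument form is what the older script used):

```python
# Old call from the script, keyed on the autotuning arguments:
# best_config = _attn_fwd.get_best_config(N_CTX=q.shape[2], STAGE=stage)

# Call that works on this build, where the autotuner caches a single best config:
best_config = _attn_fwd.get_best_config()
```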
It seems that this is a very old version of the flash-attention kernel from the repo. Can you try the latest one, at the tip of the triton-mlir branch?
@zhanglx13 Thanks for your reply, I will try the latest one.
@zhanglx13 I updated to the latest one and the results seem OK. Additionally, if I integrate Triton Flash Attention with vLLM, which Triton FA code should I use?
Let's ask @vgokhale for help
@vgokhale Thanks for your reply. I just tested this code (python/perf-kernels/flash-attention.py) and it core dumps:
fused-attention-fwd-d128-causal=False-bias=False:
BATCH H N_CTX Triton
0 16.0 16.0 1024.0 61.429768
1 8.0 16.0 2048.0 68.134906
2 4.0 16.0 4096.0 71.954407
3 2.0 16.0 8192.0 73.787859
4 1.0 16.0 16384.0 74.633200
5 2.0 48.0 1024.0 58.086475
6 2.0 48.0 2048.0 66.287027
7 2.0 48.0 4096.0 71.869249
8 2.0 48.0 8192.0 73.561685
9 2.0 48.0 16384.0 73.963175
10 8.0 16.0 1989.0 64.129571
11 4.0 16.0 4097.0 65.608565
12 2.0 16.0 8122.0 72.668382
13 1.0 16.0 16281.0 73.702959
14 2.0 48.0 1021.0 55.679723
15 2.0 48.0 2001.0 62.954609
16 2.0 48.0 3996.0 68.704991
17 2.0 48.0 8181.0 73.365176
fused-attention-fwd-d128-causal=False-bias=True:
BATCH H N_CTX Triton
0 16.0 16.0 1024.0 36.203390
1 8.0 16.0 2048.0 46.371167
2 4.0 16.0 4096.0 52.490864
3 2.0 16.0 8192.0 55.326495
4 1.0 16.0 16384.0 56.978791
5 2.0 48.0 1024.0 31.577685
6 2.0 48.0 2048.0 43.559168
7 2.0 48.0 4096.0 52.591746
8 2.0 48.0 8192.0 55.379588
9 2.0 48.0 16384.0 57.056650
10 8.0 16.0 1989.0 38.848884
11 4.0 16.0 4097.0 44.244679
12 2.0 16.0 8122.0 53.214392
13 1.0 16.0 16281.0 55.554325
14 2.0 48.0 1021.0 27.077047
15 2.0 48.0 2001.0 36.845493
16 2.0 48.0 3996.0 47.236803
17 2.0 48.0 8181.0 54.186119
python3: /home/runner/work/triton/triton/llvm-project/mlir/lib/Analysis/SliceAnalysis.cpp:109: void getBackwardSliceImpl(mlir::Operation *, SetVector<mlir::Operation *> *, mlir::BackwardSliceOptions): Assertion `parentOp->getNumRegions() == 1 && parentOp->getRegion(0).getBlocks().size() == 1' failed.
Aborted (core dumped)
This is a bit weird - it completes the benchmark and core dumps somewhere after. These numbers are also pretty low, even for MI210. A few points:
1) Is it possible to try with a more recent ROCm version / docker image? 5.5 is quite old.
2) When you git fetch'ed triton-mlir to use the latest, I assume you rebuilt Triton? If not, that would be good to do; see the sketch after this list.
3) This version is a little bit in flux this week. If possible, can you wait while we push all updates? Once we push the new kernel (sometime this week), you'll be working with new code, so fixing issues in this one, if any, may not make sense.
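For point 2, a minimal sketch of the rebuild, assuming the standard Triton source layout (the repo path and remote name are placeholders, not from this thread):

```python
import subprocess
import sys

def rebuild_triton(repo: str = "triton") -> None:
    # Fetch and check out the tip of the triton-mlir branch.
    subprocess.run(["git", "-C", repo, "fetch", "origin"], check=True)
    subprocess.run(["git", "-C", repo, "checkout", "origin/triton-mlir"], check=True)
    # Reinstall so the compiler is rebuilt against the new sources;
    # Triton's setup.py lives in the python/ subdirectory.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", f"{repo}/python"],
        check=True,
    )
```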
@vgokhale Thanks for your reply, I will update the ROCm version and try. Also, is there a Flash-Decoding algorithm implemented in Triton?
Hi @zhangxiao-stack,
flash-attention.py should now run without the core dump you faced above. Can you try with the latest triton-mlir?
Re. flash-decoding, we are working on this. We expect a first version this week.
Hi @vgokhale, I fetched the latest triton-mlir and rebuilt Triton. flash-attention.py runs without the core dump, but the numbers are pretty low:
fused-attention-fwd-d128-causal=False:
BATCH H N_CTX_Q N_CTX_K Triton
0 16.0 16.0 1024.0 1024.0 79.214163
1 8.0 16.0 2048.0 2048.0 86.808619
2 4.0 16.0 4096.0 4096.0 90.950733
3 2.0 16.0 8192.0 8192.0 93.136436
4 1.0 16.0 16384.0 16384.0 93.901165
5 2.0 48.0 1024.0 1024.0 75.596109
6 2.0 48.0 2048.0 1024.0 76.366666
7 2.0 48.0 4096.0 8192.0 92.684200
8 2.0 48.0 8192.0 4096.0 90.778724
9 2.0 48.0 16384.0 8192.0 92.719959
10 8.0 16.0 1989.0 15344.0 90.903315
11 4.0 16.0 4097.0 163.0 34.711483
12 2.0 16.0 8122.0 2159.0 85.837150
13 1.0 16.0 16281.0 7.0 2.098211
14 2.0 48.0 1021.0 1020.0 73.826180
15 2.0 48.0 2001.0 2048.0 81.380739
16 2.0 48.0 3996.0 9639.0 90.536361
17 2.0 48.0 8181.0 1021.0 78.771301
fused-attention-fwd-d128-causal=True:
BATCH H N_CTX_Q N_CTX_K Triton
0 16.0 16.0 1024.0 1024.0 12.886308
1 8.0 16.0 2048.0 2048.0 13.848956
2 4.0 16.0 4096.0 4096.0 14.383211
3 2.0 16.0 8192.0 8192.0 14.646250
4 1.0 16.0 16384.0 16384.0 14.777015
5 2.0 48.0 1024.0 1024.0 12.670611
6 2.0 48.0 2048.0 1024.0 25.142167
7 2.0 48.0 4096.0 8192.0 9.854398
8 2.0 48.0 8192.0 4096.0 28.748776
9 2.0 48.0 16384.0 8192.0 29.420537
10 8.0 16.0 1989.0 15344.0 7.728812
11 4.0 16.0 4097.0 163.0 38.106455
12 2.0 16.0 8122.0 2159.0 49.793807
13 1.0 16.0 16281.0 7.0 1.635808
14 2.0 48.0 1021.0 1020.0 12.503968
15 2.0 48.0 2001.0 2048.0 13.502583
16 2.0 48.0 3996.0 9639.0 9.313698
17 2.0 48.0 8181.0 1021.0 91.321226
fused-attention-varlen-fwd-d128:
BATCH HQ HK N_CTX Triton
0 2.0 16.0 4.0 1024.0 5.716963
1 8.0 16.0 2.0 2048.0 4.230196
2 4.0 16.0 8.0 4096.0 16.136856
3 2.0 16.0 4.0 8192.0 13.651264
4 2.0 16.0 8.0 16384.0 82.597363
5 2.0 48.0 12.0 1024.0 17.080724
6 2.0 48.0 24.0 2048.0 17.039837
7 2.0 48.0 8.0 4096.0 55.623457
8 2.0 48.0 4.0 8192.0 79.230780
9 2.0 48.0 2.0 16384.0 87.740621
10 2.0 64.0 32.0 1024.0 22.524289
11 4.0 64.0 16.0 2048.0 33.392685
12 4.0 64.0 8.0 4096.0 37.548845
13 4.0 64.0 32.0 8192.0 66.148101
14 4.0 128.0 16.0 16384.0 81.158887
Hmm, these look as expected on an MI210. What baseline are you comparing against that makes these seem low?
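For context, one way to sanity-check such numbers, assuming the Triton column reports achieved TFLOPS for the forward pass (an assumption; the script's actual reporting may differ):

```python
# fwd attention does two matmuls (Q @ K^T and P @ V), each costing
# 2 * B * H * N_CTX_Q * N_CTX_K * D_HEAD FLOPs (2 FLOPs per multiply-accumulate);
# a causal mask roughly halves the work.
def attn_fwd_tflops(batch, heads, n_ctx_q, n_ctx_k, d_head, ms, causal=False):
    flops = 4.0 * batch * heads * n_ctx_q * n_ctx_k * d_head
    if causal:
        flops *= 0.5
    return flops / (ms * 1e-3) / 1e12
```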
Hi @vgokhale. Sorry for the late reply. First, I got it wrong; the numbers are as expected. Second, below are flash-decoding results on an MI210:
fused-attention-d128-fwd-causal=False:
B Mq Mkv Hq Hkv K Triton
0 256.0 1.0 256.0 16.0 1.0 128.0 593.210161
1 128.0 1.0 512.0 16.0 1.0 128.0 588.659704
2 64.0 1.0 1024.0 16.0 1.0 128.0 582.193434
3 32.0 1.0 2048.0 16.0 1.0 128.0 578.703046
4 16.0 1.0 4096.0 16.0 1.0 128.0 581.533909
5 8.0 1.0 8192.0 16.0 1.0 128.0 581.341326
6 4.0 1.0 16384.0 16.0 1.0 128.0 590.296090
7 2.0 1.0 32768.0 16.0 1.0 128.0 580.572486
8 1.0 1.0 65536.0 16.0 1.0 128.0 582.854867
9 1.0 1.0 131072.0 16.0 1.0 128.0 580.364287
10 256.0 1.0 256.0 16.0 2.0 128.0 585.848987
11 128.0 1.0 512.0 16.0 2.0 128.0 583.298564
12 64.0 1.0 1024.0 16.0 2.0 128.0 579.619348
13 32.0 1.0 2048.0 16.0 2.0 128.0 587.042093
14 16.0 1.0 4096.0 16.0 2.0 128.0 581.362069
15 8.0 1.0 8192.0 16.0 2.0 128.0 586.579323
16 4.0 1.0 16384.0 16.0 2.0 128.0 587.934375
17 2.0 1.0 32768.0 16.0 2.0 128.0 586.112440
18 1.0 1.0 65536.0 16.0 2.0 128.0 588.137567
19 1.0 1.0 131072.0 16.0 2.0 128.0 585.379064
Does this result look ok?
What script are you using for flash decoding? I don't think we have one checked in at the tip of the triton-mlir branch yet.
@vgokhale I use https://github.com/ROCm/triton/pull/492/files
Hi @vgokhale, python/perf-kernels/06-attention-decode.py generates wrong results now:
Mismatched elements: 8192 / 8192 (100.0%)
Greatest absolute difference: 0.4072265625 at index (0, 9, 0, 61) (up to 0.021 allowed)
Greatest relative difference: 0.407958984375 at index (0, 9, 0, 111) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[2-1-32768-16-1-128] - AssertionError: Tensor-likes are not close!
Mismatched elements: 4096 / 4096 (100.0%)
Greatest absolute difference: 0.8779296875 at index (0, 9, 0, 59) (up to 0.021 allowed)
Greatest relative difference: 0.87353515625 at index (0, 9, 0, 33) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[1-1-65536-16-1-128] - AssertionError: Tensor-likes are not close!
Mismatched elements: 2048 / 2048 (100.0%)
Greatest absolute difference: 0.99853515625 at index (0, 3, 0, 121) (up to 0.021 allowed)
Greatest relative difference: 0.99365234375 at index (0, 3, 0, 28) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[1-1-131072-16-1-128] - AssertionError: Tensor-likes are not close!
Mismatched elements: 2048 / 2048 (100.0%)
Greatest absolute difference: 1.0048828125 at index (0, 0, 0, 58) (up to 0.021 allowed)
Greatest relative difference: 1.0 at index (0, 0, 0, 0) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[8-1-8192-16-2-128] - AssertionError: Tensor-likes are not close!
Mismatched elements: 16384 / 16384 (100.0%)
Greatest absolute difference: 0.0791015625 at index (6, 11, 0, 55) (up to 0.021 allowed)
Greatest relative difference: 0.080810546875 at index (6, 11, 0, 55) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[4-1-16384-16-2-128] - AssertionError: Tensor-likes are not close!
Mismatched elements: 8192 / 8192 (100.0%)
Greatest absolute difference: 0.4091796875 at index (3, 1, 0, 122) (up to 0.021 allowed)
Greatest relative difference: 0.408447265625 at index (3, 1, 0, 97) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[2-1-32768-16-2-128] - AssertionError: Tensor-likes are not close!
Mismatched elements: 4096 / 4096 (100.0%)
Greatest absolute difference: 0.87939453125 at index (0, 9, 0, 55) (up to 0.021 allowed)
Greatest relative difference: 0.87646484375 at index (0, 9, 0, 90) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[1-1-65536-16-2-128] - AssertionError: Tensor-likes are not close!
Mismatched elements: 2048 / 2048 (100.0%)
Greatest absolute difference: 1.0 at index (0, 3, 0, 89) (up to 0.021 allowed)
Greatest relative difference: 0.994140625 at index (0, 9, 0, 2) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[1-1-131072-16-2-128] - AssertionError: Tensor-likes are not close!
Mismatched elements: 2048 / 2048 (100.0%)
Greatest absolute difference: 1.005859375 at index (0, 11, 0, 85) (up to 0.021 allowed)
Greatest relative difference: 1.0 at index (0, 0, 0, 0) (up to 0 allowed)
========================= 19 failed, 22 passed in 69.18s (0:01:09) =========================
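The "Tensor-likes are not close!" failures above come from torch.testing.assert_close. A minimal sketch of that style of check, with the tolerances read off the log rather than taken from the test's source (so atol ≈ 0.021 and rtol = 0 are assumptions):

```python
import torch

def check_decode_output(tri_out: torch.Tensor, ref_out: torch.Tensor) -> None:
    # Raises AssertionError with mismatch statistics like those above when the
    # Triton output drifts from the PyTorch reference beyond the tolerance.
    torch.testing.assert_close(tri_out, ref_out, atol=2.1e-2, rtol=0)
```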
Hi @scxiao, since you have sent a PR, I imagine it passes all unit tests?
@vgokhale Hi, after making modifications based on pull request #541:
1. The official python/perf-kernels/flash-attention.py unit test test_op_bwd fails with big absolute differences:
pytest flash-attention.py -k test_op_bwd
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(4, 48, 1024, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64)])
FAILED flash-attention.py::test_op_bwd[4-48-1024-64] - AssertionError: Tensor-likes are not close!
Mismatched elements: 8701164 / 12582912 (69.2%)
Greatest absolute difference: 6.38818359375 at index (3, 40, 0, 19) (up to 0.05 allowed)
Greatest relative difference: inf at index (2, 28, 26, 36) (up to 0 allowed)
FAILED flash-attention.py::test_op_bwd[4-48-2048-64] - AssertionError: Tensor-likes are not close!
Mismatched elements: 12801585 / 25165824 (50.9%)
Greatest absolute difference: 7.03271484375 at index (2, 16, 0, 6) (up to 0.05 allowed)
Greatest relative difference: inf at index (2, 39, 1344, 41) (up to 0 allowed)
FAILED flash-attention.py::test_op_bwd[4-48-4096-64] - AssertionError: Tensor-likes are not close!
Mismatched elements: 6034803 / 50331648 (12.0%)
Greatest absolute difference: 5.860401153564453 at index (1, 25, 0, 57) (up to 0.05 allowed)
Greatest relative difference: inf at index (1, 0, 48, 1) (up to 0 allowed)
2. The official python/perf-kernels/flash-attention.py unit test test_op_bwd fails at (1, 16, 8192, 64) with a memory access fault:
pytest flash-attention.py -k test_op_bwd
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(1, 16, 8192, 64)])
flash-attention.py Memory access fault by GPU node-10 (Agent handle: 0x6e2de00) on address 0x7eff6bdeb000. Reason: Unknown.
For the FA bwd kernel, please use python/tutorials/06-fused-attention.py.
We are currently working on supporting bwd in the perf-kernels folder; until then, the tutorials folder is the right one to use for bwd.
Closing due to no new updates.
Problem Description
Hi, can somebody please take a look at this? I just tested this code and it generates wrong results.
Operating System
NAME="Ubuntu" VERSION="20.04.5 LTS (Focal Fossa)"
CPU
Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz
GPU
AMD Instinct MI210
ROCm Version
ROCm 5.5.0
ROCm Component
No response
Steps to Reproduce
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
Additional Information
No response