Which device do you have? The main branch should work for MI GPUs.
I have MI250X and MI300A.
AFAIK, the work in this repository was merged into the upstream flash-attn (see https://github.com/Dao-AILab/flash-attention/pull/1010).
I just ended up using the tags advertised on the upstream: https://github.com/Dao-AILab/flash-attention/tags
It works well provided you play with GPU_ARCHS and BUILD_TARGET.
I am glad that it works on your devices.
By the way, may I request a moment of your time? I am currently investigating AMD support for Flash Attention as an RX 7900 XTX owner. Based on the data I have gathered, Flash Attention performance on RDNA 3 devices does not appear to be particularly impressive, so I am quite curious about its performance on CDNA devices.
I have a benchmark script here (set TEST_FLASH_TRITON = False to exclude invalid tests):
You can prepare the environment like this:
python3 -m venv venv
source venv/bin/activate
pip3 install pytest pandas matplotlib
# the stable torch build should include aotriton, which already works for CDNA devices
pip3 install torch --index-url https://download.pytorch.org/whl/rocm6.1
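# optional sanity check: confirm torch sees the GPU and reports flash SDP as enabled
python3 -c "import torch; print(torch.cuda.get_device_name(0), torch.backends.cuda.flash_sdp_enabled())"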
# I am not sure whether this triton build works with CDNA devices out of the box
pip3 install triton
# guessed steps for installing flash attention in your env
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention
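# packaging and ninja are usually needed/helpful when building flash-attn from source
pip3 install packaging ninja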
python3 setup.py install
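# or, as you mentioned, pin the target/arch explicitly for CDNA
# (guessed values: gfx90a = MI250X, gfx942 = MI300A/X):
# BUILD_TARGET="rocm" GPU_ARCHS="gfx90a" python3 setup.py install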
# run the benchmark, which should only take a few seconds
python3 performance.py
I am particularly interested in the numbers for "Flash", "Torch", and "Torch Math".
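For reference, those three columns presumably correspond to something like the sketch below (a minimal, hypothetical example, not the actual performance.py; it assumes flash_attn is installed and torch >= 2.3 for torch.nn.attention.sdpa_kernel):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from flash_attn import flash_attn_func

def bench_ms(fn, iters=30):
    # warm up, then time with CUDA events (these work through HIP on ROCm)
    for _ in range(5):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

batch, heads, seqlen, headdim = 2, 8, 4096, 128
# flash_attn_func expects (batch, seqlen, heads, headdim)
q = torch.randn(batch, seqlen, heads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
# torch's scaled_dot_product_attention expects (batch, heads, seqlen, headdim)
qt, kt, vt = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

# "Flash": the flash-attn extension kernel
flash_ms = bench_ms(lambda: flash_attn_func(q, k, v, causal=False))

# "Torch": torch SDPA restricted to its flash-attention backend (aotriton on ROCm)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    torch_ms = bench_ms(lambda: F.scaled_dot_product_attention(qt, kt, vt))

# "Torch Math": torch SDPA forced onto the unfused math fallback
with sdpa_kernel(SDPBackend.MATH):
    math_ms = bench_ms(lambda: F.scaled_dot_product_attention(qt, kt, vt))

print(f"flash: {flash_ms:.3f} ms  torch sdpa: {torch_ms:.3f} ms  torch math: {math_ms:.3f} ms")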
Thanks in advance :)
I'll see what I can do; I also have results from this script: https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py
The Flash Attention Triton path didn't work out of the box with your script, so I disabled it:
fused-attention-batch2-head8-d128-fwd-causal=False:
     N_CTX      Flash      Torch  Torch Math
0     77.0   1.412917   1.138847    1.235458
1    256.0  10.955014  11.236593    8.967513
2   1024.0  50.381105  37.893675   23.948309
3   4096.0  93.834698  67.726780   27.688738
4   8192.0  97.697297  69.410223   32.837748
5   9216.0  99.613822  71.056432   32.647771
Using the official flash-attn benchmark, I get:
MI250X:
### causal=False, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 100.18 TFLOPs/s, bwd: 49.36 TFLOPs/s, fwd + bwd: 57.73 TFLOPs/s
Pytorch fwd: 34.59 TFLOPs/s, bwd: 50.20 TFLOPs/s, fwd + bwd: 44.46 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM
### causal=False, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 101.16 TFLOPs/s, bwd: 49.67 TFLOPs/s, fwd + bwd: 58.12 TFLOPs/s
Pytorch fwd: 27.96 TFLOPs/s, bwd: 40.23 TFLOPs/s, fwd + bwd: 35.75 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM
### causal=True, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 92.83 TFLOPs/s, bwd: 46.02 TFLOPs/s, fwd + bwd: 53.77 TFLOPs/s
Pytorch fwd: 11.95 TFLOPs/s, bwd: 26.27 TFLOPs/s, fwd + bwd: 19.57 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM
### causal=True, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 94.49 TFLOPs/s, bwd: 49.66 TFLOPs/s, fwd + bwd: 57.45 TFLOPs/s
Pytorch fwd: 10.01 TFLOPs/s, bwd: 21.15 TFLOPs/s, fwd + bwd: 16.05 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM
MI300A (not MI300X):
### causal=False, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 377.58 TFLOPs/s, bwd: 74.65 TFLOPs/s, fwd + bwd: 96.85 TFLOPs/s
Pytorch fwd: 77.26 TFLOPs/s, bwd: 118.31 TFLOPs/s, fwd + bwd: 102.71 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 384.45 TFLOPs/s, bwd: 74.77 TFLOPs/s, fwd + bwd: 97.13 TFLOPs/s
Pytorch fwd: 71.74 TFLOPs/s, bwd: 120.53 TFLOPs/s, fwd + bwd: 100.92 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM
### causal=True, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 288.32 TFLOPs/s, bwd: 68.16 TFLOPs/s, fwd + bwd: 87.18 TFLOPs/s
Pytorch fwd: 27.15 TFLOPs/s, bwd: 62.42 TFLOPs/s, fwd + bwd: 45.52 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 291.85 TFLOPs/s, bwd: 71.77 TFLOPs/s, fwd + bwd: 91.48 TFLOPs/s
Pytorch fwd: 25.17 TFLOPs/s, bwd: 64.43 TFLOPs/s, fwd + bwd: 44.57 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM
Thank you for sharing!
Meanwhile, it seems that AMD is going to merge CDNA and RDNA into a new UDNA architecture:
Description of errors
Which branch, tag, or commit SHA should I use to build flash attention?
In the past I always used the flash_attention_for_rocm branch; is that still recommended?