ROCm / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

[Documentation]: Which branch, tag or commit sha should I use to build flash attention for AMD devices? #82

Closed: etiennemlb closed this issue 2 months ago

etiennemlb commented 2 months ago

Description of errors

Which branch, tag, or commit SHA should I use to build flash attention?

In the past I always used the flash_attention_for_rocm branch. Is that still recommended?

Attach any links, screenshots, or additional evidence you think will be helpful.

No response

evshiron commented 2 months ago

Which device do you have? The main branch should work for MI GPUs.

etiennemlb commented 2 months ago

I have MI250X and MI300A.

AFAIK, the work in this repository was merged into upstream flash-attn (see https://github.com/Dao-AILab/flash-attention/pull/1010).

I just ended up using the tags advertised on the upstream: https://github.com/Dao-AILab/flash-attention/tags

It works well provided you play with GPU_ARCHS and BUILD_TARGET.
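
For reference, roughly what that looked like on my side (a sketch, not a recipe: whether BUILD_TARGET wants "rocm" here and whether GPU_ARCHS takes a semicolon- or comma-separated list may differ between tags):

# MI250X is gfx90a; MI300A/MI300X are gfx942.
export BUILD_TARGET=rocm
export GPU_ARCHS="gfx90a;gfx942"
pip install flash-attn --no-build-isolation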

evshiron commented 2 months ago

I am glad that it works on your devices.

By the way, may I request a moment of your time? I am currently investigating AMD support for Flash Attention as a RX 7900 XTX owner. Based on the data I’ve gathered, the performance of Flash Attention on RDNA 3 devices does not appear to be particularly impressive. Therefore, I am quite curious about the performance of Flash Attention on CDNA devices.

I have a benchmark script here (set TEST_FLASH_TRITON = False to exclude invalid tests):

You can prepare the environment like this:

python3 -m venv venv
source venv/bin/activate

pip3 install pytest pandas matplotlib

# the stable torch should include aotriton that already works for cdna devices
pip3 install torch --index-url https://download.pytorch.org/whl/rocm6.1

# i am not sure if this triton works with cdna devices out of the box
pip3 install triton

# guessed steps for installing flash attention in your env
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention
python3 setup.py install

# run the benchmark, which should only take a few seconds
python3 performance.py
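
If you want to double-check that the ROCm torch build actually exposes the flash SDPA path before benchmarking, something like this should do (my guess at a sanity check, not part of the script):

# Sanity check: confirm this is a ROCm build with the flash SDPA backend enabled.
import torch
print(torch.__version__, torch.version.hip)    # torch.version.hip is set on ROCm builds
print(torch.cuda.get_device_name(0))           # should report your MI GPU
print("flash SDPA:", torch.backends.cuda.flash_sdp_enabled())
print("math SDPA:", torch.backends.cuda.math_sdp_enabled())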

I am particularly interested in the numbers for "Flash", "Torch", and "Torch Math".
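
If the script itself gives you trouble, the gist of what those three numbers measure is roughly the following (a hypothetical sketch, not the actual performance.py; it assumes torch >= 2.3 with ROCm and an importable flash-attn, and reports milliseconds rather than TFLOPs/s):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from flash_attn import flash_attn_func

def bench(fn, iters=20):
    # Warm up, then time with CUDA events (the HIP backend uses the same API).
    for _ in range(3):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

batch, heads, seqlen, headdim = 2, 8, 4096, 128
q, k, v = (torch.randn(batch, seqlen, heads, headdim, device="cuda", dtype=torch.float16)
           for _ in range(3))
qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))  # SDPA wants (batch, heads, seqlen, headdim)

sdpa = torch.nn.functional.scaled_dot_product_attention
print("Flash      %.3f ms" % bench(lambda: flash_attn_func(q, k, v, causal=False)))
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    print("Torch      %.3f ms" % bench(lambda: sdpa(qt, kt, vt)))
with sdpa_kernel(SDPBackend.MATH):
    print("Torch Math %.3f ms" % bench(lambda: sdpa(qt, kt, vt)))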

Thanks in advance :)

etiennemlb commented 2 months ago

I'll see what I can do. I also have results from this script: https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py
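
(For anyone following along, running it is roughly this, assuming flash-attn and its optional Triton dependency are already installed in the active environment:)

git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention
python3 benchmarks/benchmark_flash_attention.py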

etiennemlb commented 2 months ago

The Triton flash attention didn't work out of the box with your script, so I disabled it:

fused-attention-batch2-head8-d128-fwd-causal=False:
    N_CTX      Flash      Torch  Torch Math
0    77.0   1.412917   1.138847    1.235458
1   256.0  10.955014  11.236593    8.967513
2  1024.0  50.381105  37.893675   23.948309
3  4096.0  93.834698  67.726780   27.688738
4  8192.0  97.697297  69.410223   32.837748
5  9216.0  99.613822  71.056432   32.647771

Using the official flash-attn benchmark, I get:

MI250X:

### causal=False, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 100.18 TFLOPs/s, bwd: 49.36 TFLOPs/s, fwd + bwd: 57.73 TFLOPs/s
Pytorch fwd: 34.59 TFLOPs/s, bwd: 50.20 TFLOPs/s, fwd + bwd: 44.46 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM
### causal=False, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 101.16 TFLOPs/s, bwd: 49.67 TFLOPs/s, fwd + bwd: 58.12 TFLOPs/s
Pytorch fwd: 27.96 TFLOPs/s, bwd: 40.23 TFLOPs/s, fwd + bwd: 35.75 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM
### causal=True, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 92.83 TFLOPs/s, bwd: 46.02 TFLOPs/s, fwd + bwd: 53.77 TFLOPs/s
Pytorch fwd: 11.95 TFLOPs/s, bwd: 26.27 TFLOPs/s, fwd + bwd: 19.57 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM
### causal=True, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 94.49 TFLOPs/s, bwd: 49.66 TFLOPs/s, fwd + bwd: 57.45 TFLOPs/s
Pytorch fwd: 10.01 TFLOPs/s, bwd: 21.15 TFLOPs/s, fwd + bwd: 16.05 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM

MI300A (not MI300X):

### causal=False, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 377.58 TFLOPs/s, bwd: 74.65 TFLOPs/s, fwd + bwd: 96.85 TFLOPs/s
Pytorch fwd: 77.26 TFLOPs/s, bwd: 118.31 TFLOPs/s, fwd + bwd: 102.71 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 384.45 TFLOPs/s, bwd: 74.77 TFLOPs/s, fwd + bwd: 97.13 TFLOPs/s
Pytorch fwd: 71.74 TFLOPs/s, bwd: 120.53 TFLOPs/s, fwd + bwd: 100.92 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM
### causal=True, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 288.32 TFLOPs/s, bwd: 68.16 TFLOPs/s, fwd + bwd: 87.18 TFLOPs/s
Pytorch fwd: 27.15 TFLOPs/s, bwd: 62.42 TFLOPs/s, fwd + bwd: 45.52 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 291.85 TFLOPs/s, bwd: 71.77 TFLOPs/s, fwd + bwd: 91.48 TFLOPs/s
Pytorch fwd: 25.17 TFLOPs/s, bwd: 64.43 TFLOPs/s, fwd + bwd: 44.57 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s # OOM

evshiron commented 2 months ago

Thank you for sharing!

Meanwhile, it seems that AMD is going to merge CDNA and RDNA into a new UDNA architecture: