ROCm / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

RDNA3 support #27

Open WilliamGazeley opened 12 months ago

WilliamGazeley commented 12 months ago

Great work so far. I'm trying to run vLLM on my 7900XTX cards and was wondering if there were any plans to support RDNA3?

sdli1995 commented 12 months ago

A CK discussion has shown a branch with a flash-attention kernel implementation that already works in AIT: https://github.com/ROCmSoftwarePlatform/composable_kernel/discussions/1032. Are there any barriers to RDNA3 support?

dejay-vu commented 12 months ago

Hi @WilliamGazeley and @sdli1995, I wanted to update you on the attention kernels for the NAVI platform. My colleague @aska-0096 did activate them. However, these Flash-Attention kernels were initially developed for MI devices, which operate on a distinct set of CK kernels.

In essence, the issue is that we haven't yet integrated these kernels into our current API. I plan to work on this integration in my spare time.

AlpinDale commented 12 months ago

Great work on the fork, @howiejayz

Will it take too long to port the kernels for gfx1100 and other non-MI architectures? I need the kernels for my project and I'd be willing to help out if I can.

dejay-vu commented 11 months ago

Thanks for reaching out and offering to help @AlpinDale! I'm currently tied up with a few other projects, so I can't give an exact timeframe for porting the kernels for gfx1100 and other architectures. But I'm definitely planning to tackle this soon. The first step will be creating a new code path for gfx110x, considering their CK kernels are only for forward ops.

I'm totally open to suggestions or any help you can provide. It'd be great to have some extra hands on this. Let me know if you're interested!

evshiron commented 11 months ago

I am a complete novice in this field, but a few months ago I managed to make Composable Kernel, Flash Attention, and PyTorch work together for my RX 7900 XTX (see here, sort by performance, and look for the first one). Although I was able to get that Flash Attention implementation "working" in the end, the generated images were meaningless, and I gave up because I didn't know how to fix it. Here are the relevant branch links, and I hope they can be of some help to you:

I made a grouped fused kernel by Frankensteining the batched fused kernel, which matched the call signatures in this repo at that time. However, that self-made kernel might just be broken.

dejay-vu commented 11 months ago

Hi @evshiron! First off, I must say I'm seriously impressed by your work! It's quite an achievement, and the resources you've provided are invaluable.

I've had the opportunity to build your implementation on gfx1100, and I'm pleased to report that the build was successful. However, I encountered an issue with the unit tests not passing due to incorrect results in the forward pass:

```
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
AssertionError: assert 0.109619140625 <= (2 * 0.000244140625)
```

This likely stems from incorrect parameter settings in the CK kernels, and I suspect it is also the reason the output images became meaningless.
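
For context, here is a minimal sketch of the tolerance convention the assertion above reflects (assumed from the failing check itself, not copied from the test suite): the Flash Attention output may deviate from the fp32 reference by at most twice the deviation of a plain PyTorch fp16 implementation.

```python
import torch

def check_forward(output, output_pt, output_ref, factor=2.0):
    # output:     Flash Attention result (fp16)
    # output_pt:  plain PyTorch attention result (fp16)
    # output_ref: reference attention computed in fp32
    fa_err = (output - output_ref).abs().max().item()
    pt_err = (output_pt - output_ref).abs().max().item()
    # the run above fails here: 0.1096... is far larger than 2 * 0.000244...
    assert fa_err <= factor * pt_err, f"{fa_err} > {factor} * {pt_err}"
```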

Despite this, your work has been immensely helpful! This will massively speed up the navi porting process for the v2 implementation.

evshiron commented 11 months ago

@howiejayz

I'm glad that my humble work could be of some help. I am indeed unfamiliar with this field, so I can only leave it to professionals. Furthermore, as you can see, even though I managed to compile it, the improvement in the benchmark is quite limited (I didn't use the specific commit shown here). I hope it's just an issue with my implementation, and I look forward to better performance in future implementations.

dejay-vu commented 11 months ago

Guys, I have added batched forward (consistent sequence lengths) support for gfx1100, gfx1101, and gfx1102 under this branch, thanks to @aska-0096's CK kernels. The implementation is still under development and there is a lot to fine-tune. For now I see that performance is generally better when head dim = 64.

To install, just use `pip install .`

I only had the chance to test it on gfx1100, but I expect it works as well for the other two. Let me know if there is any issue! The Docker image I used for testing is `rocm/pytorch:latest`, with torch==2.1.0.
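
For anyone wanting to try it, here is a minimal forward-only usage sketch (assuming the branch is installed and exposes the standard `flash_attn_func`; the shapes below are illustrative):

```python
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim) -- head dim 64 is the well-performing case mentioned above
q = torch.randn(8, 2048, 16, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

with torch.no_grad():  # only the forward op is implemented on this code path
    out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([8, 2048, 16, 64])
```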

xzuyn commented 11 months ago

under this branch.

benchmark_flash_attention_forward.py works, but benchmark_flash_attention.py doesn't. Forward speeds look pretty nice.

Using a 7900XTX with torch 2.2.0.dev20231209+rocm5.7.

Results for `benchmark_flash_attention_forward.py`:

| causal | headdim | batch_size | seqlen | Flash2 fwd (TFLOPs/s) | Pytorch fwd (TFLOPs/s) | Triton fwd (TFLOPs/s) |
|---|---|---|---|---|---|---|
| False | 64 | 32 | 512 | 38.52 | 13.37 | 0.00 |
| False | 64 | 16 | 1024 | 39.33 | 14.22 | 0.00 |
| False | 64 | 8 | 2048 | 41.23 | 16.04 | 0.00 |
| False | 64 | 4 | 4096 | 42.06 | 18.02 | 0.00 |
| False | 64 | 2 | 8192 | 42.02 | 19.16 | 0.00 |
| False | 64 | 1 | 16384 | 37.27 | 0.00 | 0.00 |
| False | 128 | 32 | 512 | 28.27 | 20.43 | 0.00 |
| False | 128 | 16 | 1024 | 29.38 | 21.05 | 0.00 |
| False | 128 | 8 | 2048 | 30.49 | 25.23 | 0.00 |
| False | 128 | 4 | 4096 | 31.00 | 26.99 | 0.00 |
| False | 128 | 2 | 8192 | 27.50 | 28.47 | 0.00 |
| False | 128 | 1 | 16384 | 20.67 | 0.00 | 0.00 |
| True | 64 | 32 | 512 | 24.02 | 5.07 | 0.00 |
| True | 64 | 16 | 1024 | 29.08 | 5.48 | 0.00 |
| True | 64 | 8 | 2048 | 33.49 | 5.84 | 0.00 |
| True | 64 | 4 | 4096 | 36.44 | 6.21 | 0.00 |
| True | 64 | 2 | 8192 | 38.54 | 0.00 | 0.00 |
| True | 64 | 1 | 16384 | 39.70 | 0.00 | 0.00 |
| True | 128 | 32 | 512 | 17.89 | 8.42 | 0.00 |
| True | 128 | 16 | 1024 | 21.69 | 8.68 | 0.00 |
| True | 128 | 8 | 2048 | 25.64 | 9.78 | 0.00 |
| True | 128 | 4 | 4096 | 27.06 | 10.01 | 0.00 |
| True | 128 | 2 | 8192 | 27.43 | 9.87 | 0.00 |
| True | 128 | 1 | 16384 | 24.68 | 0.00 | 0.00 |

Results for `test_flash_attn_wmma_rocm.py`:

```
=============== 125 failed, 2148 passed, 4606 skipped in 46.01s ================
```

Full log: [test_flash_attn_wmma_rocm.log](https://github.com/ROCmSoftwarePlatform/flash-attention/files/13626748/test_flash_attn_wmma_rocm.log)

Error for `benchmark_flash_attention.py`:

```
> python benchmarks/benchmark_flash_attention.py
Traceback (most recent call last):
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/flash-attention/benchmarks/benchmark_flash_attention.py", line 97, in
    f, b = time_fwd_bwd(
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/flash-attention/benchmarks/benchmark_flash_attention.py", line 66, in time_fwd_bwd
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/utils/benchmark.py", line 99, in benchmark_fwd_bwd
    benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/utils/benchmark.py", line 53, in benchmark_backward
    m = t.timeit(repeats)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 274, in timeit
    self._timeit(number=max(int(number // 100), 2))
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 264, in _timeit
    return max(self._timer.timeit(number), 1e-9)
  File "/usr/lib/python3.10/timeit.py", line 178, in timeit
    timing = self.inner(it, self.timer)
  File "", line 6, in inner
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/utils/benchmark.py", line 46, in f
    y.backward(grad, retain_graph=True)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/_tensor.py", line 503, in backward
    torch.autograd.backward(
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 109, in backward
    _flash_attn_backward(
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 66, in _flash_attn_backward
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
TypeError: bwd(): incompatible function arguments. The following argument types are supported:
    1. () -> None
```

AlpinDale commented 11 months ago

Some benchmark results.

RTX 4090

### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 150.90 TFLOPs/s,
Pytorch fwd: 20.23 TFLOPs/s,
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 154.49 TFLOPs/s,
Pytorch fwd: 23.69 TFLOPs/s,
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 171.80 TFLOPs/s,
Pytorch fwd: 26.21 TFLOPs/s,
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 172.81 TFLOPs/s,
Pytorch fwd: 27.89 TFLOPs/s,
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 172.96 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 173.04 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,

7900 XTX

### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 42.98 TFLOPs/s
Pytorch fwd: 14.38 TFLOPs/s
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 44.31 TFLOPs/s
Pytorch fwd: 14.83 TFLOPs/s
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 48.25 TFLOPs/s
Pytorch fwd: 17.32 TFLOPs/s
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 47.55 TFLOPs/s
Pytorch fwd: 19.40 TFLOPs/s
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 38.40 TFLOPs/s
Pytorch fwd: 20.19 TFLOPs/s
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 41.01 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s
sdli1995 commented 11 months ago

*(quoting @AlpinDale's RTX 4090 and 7900 XTX benchmark results above)*

The 4090's fp16-accumulate fp16 tensor-core peak is about 330 TFLOPs, while the 7900 XTX's is about 120 TFLOPs, so the better NVIDIA reference card is the RTX 3090.

aska-0096 commented 11 months ago

*(quoting the RTX 4090 and 7900 XTX benchmark results and @sdli1995's comment above)*

Thanks for the benchmark data. We are going to launch a new version of Composable Kernel with better flash-attention performance, and adapting those optimizations to RDNA3 is in my plan.

AlpinDale commented 11 months ago

@sdli1995 here are the benchmarks with a 3090:

### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 65.38 TFLOPs/s,
Pytorch fwd: 18.38 TFLOPs/s,
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 72.94 TFLOPs/s,
Pytorch fwd: 21.69 TFLOPs/s,
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 74.11 TFLOPs/s,
Pytorch fwd: 18.92 TFLOPs/s,
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 74.98 TFLOPs/s,
Pytorch fwd: 22.27 TFLOPs/s,
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 75.06 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 75.12 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,
AlpinDale commented 11 months ago

Any updates on this?

Wintoplay commented 11 months ago

We need official support for flash attention

ewof commented 11 months ago

trust bro, be patient don't rush them

gel-crabs commented 11 months ago

I've been using the howiejayz/navi_support branch on here with stable-diffusion-webui for a few weeks now. The implementation is perfect.

On an RX 7800 XT, it speeds it/s up from 1.75 it/s to 2 it/s, all while massively decreasing VRAM usage.

Kademo15 commented 11 months ago

I've been using the howiejayz/navi_support branch on here with stable-diffusion-webui for a few weeks now. The implementation is perfect.

On an RX 7800 XT, it speeds it/s up from 1.75 it/s to 2 it/s, all while massively decreasing VRAM usage.

Could you please provide more information on how? Did you just install the branch and it worked out of the box, or did you have to change the code of the webui you are using?

Wintoplay commented 11 months ago

@gel-crabs I failed to install flash-attn for Navi. Please give more info.

gel-crabs commented 11 months ago

@Kademo15 @Wintoplay

Alright, I'm going to try to give instructions on how I got this to work. If you're on Arch, I have a very amateur PKGBUILD (requiring --skipinteg) that gets it to work. You need to go into the PKGBUILD, replace GPU_ARCHS=gfx1101 with your GPU's architecture, and set MAX_JOBS to however many CPU cores you have. I can only confirm it will work on gfx11+. The patch just changes the C++ standard from c++20 to c++17 to allow it to build.

python-flash-attention.tar.gz

If you aren't on Arch, you can generally just follow the commands and install the Python wheel file afterwards, in your virtualenv if you're using one. You can clone the repo with `git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git -b howiejayz/navi_support --depth=1` in this case.

Now for webui. You will have to use a patch that has been closed since it will be obsolete once AMD finishes xformers support.

https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902.patch

That's the link to the raw patch file; use `patch -p1 < 11902.patch` in the webui directory to apply it to your webui. Please run `patch -p1 --dry-run < 11902.patch` first so it won't screw up your installation if it doesn't apply correctly. We're not done yet, however.

https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902

In the discussion for this patch, I posted a long comment on getting it to work. (https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902#issuecomment-1799512269). The important info from that is the 2 code blocks you need to change manually.

After that, add `--flash-attn` to your command line arguments and it should work. If generation is slower or the same speed, flash attention isn't working. You may get a HIP OOM at the end of generation if you're using a higher resolution than usual, as it needs to switch back to SDP at the end due to not supporting a head dim over 128.

If you get an error about the flash-attn .so not loading, rebuild the PKGBUILD but change CC=clang and CXX=clang++ to CC=gcc and CXX=g++.

Beinsezii commented 11 months ago

Switching setup.py to c++17 built successfully on gfx1100

Seems to work in transformers, as a 7b model OOMs @ >8k context using the default attention but doesn't even crack 20 gigs with FA2 enabled. Interestingly I lose about 1 t/s though?
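
For reference, this is roughly how Flash Attention 2 gets enabled in Transformers (a hedged sketch; recent versions use the `attn_implementation` argument, older ones used `use_flash_attention_2=True`, and the model name below is only an example):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",               # any 7B causal LM, for illustration
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="flash_attention_2",  # routes attention through flash_attn
)
```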

I'll have to see if I can monkeypatch it into Diffusers...

Wintoplay commented 11 months ago

@gel-crabs I did that, and also tried `FLASH_ATTENTION_INTERNAL_USE_RTN=1 pip install .`; the testing just gives FFFFF results.

(I use Debian.) I have not tried the SD patch though, since I want it for LLM inference.

gel-crabs commented 11 months ago

@gel-crabs I did that, and also tried `FLASH_ATTENTION_INTERNAL_USE_RTN=1 pip install .`; the testing just gives FFFFF results.

(I use Debian.) I have not tried the SD patch though, since I want it for LLM inference.

Do you mean the unit testing? For that you need to export FLASH_ATTENTION_INTERNAL_UNIT_TEST_MODE=1 and FLASH_ATTENTION_INTERNAL_DETERMINISTIC=1.

You should also set GPU_ARCHS to your GPU architecture (gfx1100, gfx1101, etc.) and try building with both GCC and Clang. I can also only guarantee this will work on ROCm 5.7 and up.

For anything other than SD webui, you will likely have to create a forward function yourself, as it is a PyTorch extension and isn't integrated into PyTorch yet. The implementation is here, but keep in mind it requires my changes as well:

https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902/files
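
As a rough illustration of what such a forward function looks like, here is a sketch only (not the code from the patch); it assumes the webui-style packed (batch, seqlen, heads * headdim) layout and the einops package:

```python
import torch
from einops import rearrange
from flash_attn import flash_attn_func

def flash_attention_forward(q_in, k_in, v_in, heads):
    # split the packed head dimension: (b, s, h*d) -> (b, s, h, d)
    q, k, v = (rearrange(t, "b s (h d) -> b s h d", h=heads) for t in (q_in, k_in, v_in))
    # the navi branch only implements the forward pass, so this is inference-only
    out = flash_attn_func(q.half(), k.half(), v.half())
    return rearrange(out, "b s h d -> b s (h d)").to(q_in.dtype)
```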

j-dominguez9 commented 11 months ago

I don't have the knowledge to contribute to this issue, but I'm really rooting for this support feature!

Kademo15 commented 11 months ago

You may get a HIP OOM at the end of generation if you're using a higher resolution than usual, as it needs to switch back to SDP at the end due to not supporting a head dim over 128.

Could you share what you found? What's the size you can go up to before it has to fall back to SDP?

feffy380 commented 11 months ago

I've been using the howiejayz/navi_support branch on here with stable-diffusion-webui for a few weeks now. The implementation is perfect.

On an RX 7800 XT, it speeds it/s up from 1.75 it/s to 2 it/s, all while massively decreasing VRAM usage.

I found larger resolutions benefit more. On SD1.5, 1088x1088 on an RX 7900 XTX went from 1.67 it/s to 3 it/s while 512px was a more modest 16 it/s to 18 it/s. VRAM usage also drops dramatically. Generating a 1088x1088 image goes from 18GB down to about 6GB. I don't see any spike at the end, though. Is it specific to SDXL maybe?

Note: If using a venv I found I had to build the wheel with the venv activated. Otherwise, the library complained about undefined symbols and failed to load.

gel-crabs commented 11 months ago

You may get a HIP OOM at the end of generation if you're using a higher resolution than usual, as it needs to switch back to SDP at the end due to not supporting a head dim over 128.

Could you share what you found? What's the size you can go up to before it has to fall back to SDP?

I think I explained wrong; it will always fall back to SDP at the very end of generation (after all the steps have finished), resolution doesn't factor into it.

What I meant is that the massively decreased VRAM usage will (currently..?) not allow you to use a higher resolution than with regular SDP attention, as the VRAM usage will jump back up at the end.

However, it most likely could... if AMD hooked up CK's Navi branch to their new xformers port. ;)

gel-crabs commented 10 months ago

Also note: the switch back to SDP at the end (I believe only with SDXL) can be prevented by switching from full VAE to TAESD, or (I assume) a tiled VAE implementation.

This allows a 1024x1024 image to be upscaled to 2048x2048 with SDXL on an RX 7800 XT with 16GB of VRAM.

vire70 commented 10 months ago

I see above someone recommended a way to use this with Stable Diffusion; is something similar required to get it working with Ooga text gen? I have flash attention compiled with howiejay/navi_support (I use a 7900 XTX), but the program never detects that it is installed. It'd be great if someone could make a step-by-step guide at some point for those of us less savvy.

feffy380 commented 10 months ago

Also note: the switch back to SDP at the end (I believe only with SDXL) can be prevented by switching from full VAE to TAESD, or (I assume) a tiled VAE implementation.

This allows a 1024x1024 image to be upscaled to 2048x2048 with SDXL on an RX 7800 XT with 16GB of VRAM.

At least for this particular use-case I'm not sure falling back to SDP is the best option to begin with. AFAIK it doesn't implement any optimizations for radeon gpus, so Doggettx or sub-quadratic are more efficient despite being implemented in plain pytorch (Doggettx uses a bit less vram and handles low memory gracefully instead of crashing, sub-quadratic trades speed for far less memory usage).

gel-crabs commented 10 months ago

Also note: the switch back to SDP at the end (I believe only with SDXL) can be prevented by switching from full VAE to TAESD, or (I assume) a tiled VAE implementation. This allows a 1024x1024 image to be upscaled to 2048x2048 with SDXL on an RX 7800 XT with 16GB of VRAM.

At least for this particular use-case I'm not sure falling back to SDP is the best option to begin with. AFAIK it doesn't implement any optimizations for radeon gpus, so Doggettx or sub-quadratic are more efficient despite being implemented in plain pytorch (Doggettx uses a bit less vram and handles low memory gracefully instead of crashing, sub-quadratic trades speed for far less memory usage).

Yeah, I've already tried to get both working instead of plain SDP. I always hit a dead end because Doggettx/sub-quadratic aren't implemented directly in PyTorch itself. So we're pretty much just stuck here for now.

feffy380 commented 9 months ago

`_flash_attn_forward()` seems to be returning garbage `softmax_lse` values (7900 XTX). Is that expected?

ewof commented 9 months ago

are u on the howiejay/navi_support branch?

feffy380 commented 9 months ago

Yes. The forward pass output is within rounding error of a plain PyTorch implementation, but the LSE varies wildly (mean value shown in the attached image); without it we can't even substitute a pure PyTorch implementation for the missing backward pass.

ZhenyaPav commented 9 months ago

python-flash-attention.tar.gz

I have installed flash attention using this PKGBUILD and am getting this error when trying to load a model using exllamav2 in text-generation-webui:

19:47:00-728633 ERROR    Failed to load the model.                                                      
Traceback (most recent call last):
  File "/home/zhenyapav/Projects/text-generation-webui/modules/ui_model_menu.py", line 242, in load_model_wrapper
    shared.model, shared.tokenizer = load_model(selected_model, loader)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhenyapav/Projects/text-generation-webui/modules/models.py", line 87, in load_model
    output = load_func_map[loader](model_name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhenyapav/Projects/text-generation-webui/modules/models.py", line 371, in ExLlamav2_loader
    from modules.exllamav2 import Exllamav2Model
  File "/home/zhenyapav/Projects/text-generation-webui/modules/exllamav2.py", line 5, in <module>
    from exllamav2 import (
  File "/home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/exllamav2/__init__.py", line 3, in <module>
    from exllamav2.model import ExLlamaV2
  File "/home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/exllamav2/model.py", line 21, in <module>
    from exllamav2.attn import ExLlamaV2Attention
  File "/home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/exllamav2/attn.py", line 19, in <module>
    import flash_attn
  File "/home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import flash_attn_func
  File "/home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 4, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: /home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv
mega-ice commented 9 months ago

ImportError: /home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv

Did you compile bitsandbytes with ROCm support? You will need to uninstall the auto-installed version. For RDNA3 GPUs, try the following (I am using this version of bitsandbytes on ROCm 6.0.2, torch 2.3.0.dev20240219+rocm6.0):

```
git clone https://github.com/arlo-phoenix/bitsandbytes-rocm-5.6
ROCM_HOME=/opt/rocm ROCM_TARGET=gfx1100 make hip
pip install .
```

Beinsezii commented 9 months ago

I have installed flash attention using this PKGBUILD and am getting this error when trying to load a model using exllamav2 in text-generation-webui:

The navi flash attention branch won't work with exllamav2 regardless. It returns garbage if you bypass the v2.2.1 check.

ImportError: /home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol

A clean rebuild should fix it: `git clean -fd; python setup.py clean; make clean; python setup.py bdist_wheel`

Additionally, make sure your PyTorch and system ROCm versions are the same. If your system is on 6.0, you'll have to use PyTorch nightly and then compile exllama, llama_cpp, and others for ROCm 6, as text-gen-webui only provides 5.6 wheels.

I know PyTorch statically links to ROCm 5.6, but I believe things like pip install and the exllamav2 JIT will still look in /opt/rocm/ first.

ZhenyaPav commented 9 months ago

The navi flash attention branch won't work with exllamav2 regardless. It returns garbage if you bypass the v2.2.1 check.

What loaders do work then? AutoGPTQ?

Additionally, make sure your PyTorch and system ROCm versions are the same. If your system is on 6.0, you'll have to use PyTorch nightly and then compile exllama, llama_cpp, and others for ROCm 6, as text-gen-webui only provides 5.6 wheels.

I'm having some issues installing exllama from the repo:

ModuleNotFoundError: No module named 'torch'

Though torch nightly is installed, and I am able to import torch in python console.

Beinsezii commented 9 months ago

What loaders do work then? AutoGPTQ?

AFAIK just Transformers and even that's super rough. I see my VRAM go down but it actually runs slower than without any flash attention...

If you're on ROCm 6, the rocm_enabled bitsandbytes branch should compile and run, allowing 8/4-bit models in Transformers, but it seems to randomly cut some responses short. pytest shows a couple of tests failing due to produced infs, so I'm wondering if that's why. Also it's slow, about 1/10th the speed of Exllama at an equivalent bit depth.

If you're on 5.7 or earlier I think you need to checkout the rocm_enabled branch to the commit before they changed the HIPBLAS enum names.
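
For completeness, loading an 8-bit model in Transformers once a ROCm build of bitsandbytes is installed looks roughly like this (a hedged example; it assumes the forked bitsandbytes replaces the PyPI wheel, and the model name is illustrative):

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",                                 # illustrative model
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # 8-bit weights via bitsandbytes
    device_map="auto",
)
```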

I'm having some issues installing exllama from the repo:

ModuleNotFoundError: No module named 'torch'

Yea some builds of rocm torch nightly aren't correctly picked up by pip as satisfying the torch requirement. I have no idea why, but you can just comment out "torch" from the setup.py's dependencies list and it'll install fine. Or you could try an earlier torch nightly build, some of them don't have that issue. It seems to come and go. Maybe update pip?

Also, if you're installing exllama2 from the repo, make sure to check out a tag instead of master, as master contains breaking API changes that oobabooga-webui doesn't account for yet.

TL;DR: There's no good way to get long context on AMD cards yet. The best option is still exllama2 and just accepting the high memory usage and long 8bit cache builds.

gel-crabs commented 9 months ago

Big update for Stable Diffusion WebUI users!!

So as it turns out, it was actually super easy to replace SDP with Doggettx/sub-quadratic the whole time, I was just looking in the wrong place. XFormers does the exact same thing, just in the attnblock forward instead of the attention forward.

11902.patch.txt

Above is an updated version of the WebUI patch if you haven't applied it already (rename it to 11902.patch).

If you've already applied it, you can just replace sd_hijack_optimizations.py with this copy (rename it to sd_hijack_optimizations.py):

sd_hijack_optimizations.py.txt

Note: I chose Sub-quadratic as Doggettx has similar VRAM usage as SDP, and it only switches at the end of generation anyway so VRAM use matters more than speed here.

xhluca commented 9 months ago

FYI You can find the shader ISA (gfxOOOO number) on techpowerup, e.g.: https://www.techpowerup.com/gpu-specs/radeon-rx-6900-xt.c3481

You can see it's a gfx1030.

It seems like this issue mainly concerns RX 7000 GPUs; does that mean 6000-series GPUs won't be supported?

Beinsezii commented 9 months ago

Or simply

rocminfo | grep Name

which will give you the board names for all ROCm devices

ewof commented 9 months ago

will varlen fwd be added?

Beinsezii commented 9 months ago

I made a custom attention processor for use in Diffusers which falls back to SDP on > 128 head dims. Additionally I have a small guide going over setup.

Seems to be about +30% throughput in SDXL 1024², ramping up to +80% at the comically large 3840x2160.

Beinsezii commented 9 months ago

If anyone wants to use Flash Attention wherever Torch 2 SDP works, you can simply monkey patch it in before the call

```python
import torch

if "AMD" in torch.cuda.get_device_name() or "Radeon" in torch.cuda.get_device_name():
    try:
        from flash_attn import flash_attn_func

        sdpa = torch.nn.functional.scaled_dot_product_attention

        def sdpa_hijack(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
            if query.shape[3] <= 128 and attn_mask is None:
                hidden_states = flash_attn_func(
                    q=query.transpose(1, 2),
                    k=key.transpose(1, 2),
                    v=value.transpose(1, 2),
                    dropout_p=dropout_p,
                    causal=is_causal,
                    softmax_scale=scale,
                ).transpose(1, 2)
            else:
                hidden_states = sdpa(
                    query=query,
                    key=key,
                    value=value,
                    attn_mask=attn_mask,
                    dropout_p=dropout_p,
                    is_causal=is_causal,
                    scale=scale,
                )
            return hidden_states

        torch.nn.functional.scaled_dot_product_attention = sdpa_hijack
        print("# # #\nHijacked SDPA with ROCm Flash Attention\n# # #")
    except ImportError as e:
        print(f"# # #\nCould not load Flash Attention for hijack:\n{e}\n# # #")
else:
    print(f"# # #\nCould not detect AMD GPU from:\n{torch.cuda.get_device_name()}\n# # #")
```

Then whenever something downstream requests the Torch 2 SDP attention, it should instead use flash attention where supported.

sancspro commented 8 months ago

*(quoting @gel-crabs' Stable Diffusion WebUI update above)*

Hi there. I tried to implement this method for flash attention, and I am getting this error in webui: `RuntimeError: FlashAttention forward only supports head dimension at most 128`

I am on the latest ROCm and using a 7800 XT. Flash attention works with another repo, quickdif, but not with SD webui. Any suggestions to fix this? I mean, how do we even limit the head dim in webui?

feffy380 commented 8 months ago

@sancspro The patch is missing a check for head dims over 128. You can add this to `flash_attn_attention_forward` to fall back to `sub_quad_attention`:

```diff
 def flash_attn_attention_forward(self, x, context=None, mask=None, **kwargs):
     h = self.heads
     q_in = self.to_q(x)
     context = default(context, x)

+    if q_in.shape[-1] // h > 128:
+        return sub_quad_attention_forward(self, x, context, mask, **kwargs)
```

sancspro commented 8 months ago

Hi @feffy380, thanks for your response. I added the code as you suggested and now the error is gone, but the generated image is unusable noise. I thought maybe sub-quad attention was causing this, but when I loaded webui with `--opt-sub-quad-attention`, it works fine.

Beinsezii commented 8 months ago

Subquad uses transposed tensors compared to flash attention, so that needs to be accounted for or else it'll be garbage.

If it's FA producing the garbage, make sure the q/k/v are ordered (batch, seq, nhead, dim). Other attentions use different orderings, and FA won't error if they're mixed up; it'll just produce soup.
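
To make the ordering concrete, here is a small sketch of the layout difference (assumed shapes; SDPA-style code usually carries (batch, nheads, seqlen, headdim), while `flash_attn_func` expects (batch, seqlen, nheads, headdim)):

```python
import torch
from flash_attn import flash_attn_func

b, h, s, d = 2, 8, 1024, 64
q = torch.randn(b, h, s, d, dtype=torch.float16, device="cuda")  # SDPA-style layout
k, v = torch.randn_like(q), torch.randn_like(q)

# transpose to (b, s, h, d) for flash attention, then transpose the result back
out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
```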

sancspro commented 8 months ago

I am unable to make it work with SD webui. It either throws the head-dimension-above-128 error or just generates garbage.

I am a novice, so sorry for this question: what is head dim with respect to SD webui? Is it related to the image resolution?

On Mint, with a 7800 XT, 7800X3D, and 32GB RAM.