lhl opened 7 months ago
After taking a deep look into the code and testing Flash Attention support on AMD GPUs, here is what I found:
AMD Instinct GPUs, gfx90a and gfx942 (MI210, MI250, MI300), support Flash Attention by way of specially written Composable Kernel (CK) libraries. Although I haven't tested this myself, it is reported to be working, and there are published performance numbers showing the 2-3x speedup vLLM gives you with CK Flash Attention.
Radeon RDNA3 GPUs, 7900 XTX and W7900 (gfx1100), lack the necessary Composable Kernel libraries for the Flash Attention mechanism mentioned above, so the engineers at AMD opted to have these GPUs use an implementation of Flash Attention written in OpenAI's Triton. This Triton Flash Attention is supposed to be working, but every test I've done (using various branches and docker builds) with VLLM_USE_TRITON_FLASH_ATTN=1
hits the same "stack frame size exceeds limit" issue during the Triton JIT compile at runtime. I am sure the compile is not failing due to system resources, as I have tested this with the Radeon Pro W7900 on two powerful systems, a Ryzen 9 7950X with 64 GB of RAM and a Threadripper Pro 5975WX with 128 GB of RAM; in both cases the Triton compile takes a very long time (upwards of several hours) and still fails with the same stack frame size error (see screenshot).
Flash Attention forward-pass support for RDNA3 was added thanks to howiejay; however, this implementation no longer works in my testing, as it fails to run the hipify_python patch and build on newer versions of PyTorch+ROCm (tried on rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1 and rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1).
In summary, the only way to get vLLM working on Radeon and Radeon Pro graphics cards at the moment seems to be to build without CK Flash Attention support (BUILD_FA="0")
and disable the Triton Flash Attention implementation (VLLM_USE_TRITON_FLASH_ATTN=0).
This gets vLLM running, but you do not get any of the speedups that vLLM is known for, and in my testing inference with vLLM is the same speed as or slower than things like llama.cpp and Ollama.
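For reference, here is a minimal sketch of how I exercise that fallback configuration from Python using vLLM's offline LLM API (the model and settings mirror the server command further below; adjust to taste):

# Minimal sketch: run vLLM with the Triton flash attention path disabled.
# Assumes a vLLM ROCm build done with BUILD_FA="0"; model/settings mirror the
# server command below and are only an example.
import os
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"  # set before vLLM initializes its attention backend

from vllm import LLM, SamplingParams

llm = LLM(
    model="TechxGenus/Meta-Llama-3-70B-Instruct-GPTQ",
    quantization="gptq",
    max_model_len=3072,
    enforce_eager=True,
)
params = SamplingParams(temperature=0.8, max_tokens=128)
out = llm.generate(["Hello, my name is"], params)
print(out[0].outputs[0].text)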
The vLLM repos I've already tried are:
https://github.com/vllm-project/vllm (main branch)
https://github.com/ROCm/vllm (main, bf16_temp_fix_navi, TunableOp_Integration_ROCm6.0 branches)
https://github.com/hongxiayang/vllm (main branch)
Commands used to run the vLLM docker image and server were as follows (I tried a few other variations of the commands below, like changing --shm-size or --max-model-len, with no luck):
# Run vllm-rocm Docker image
docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G --name vllm-rocm -v /home/${USER}/Downloads/models:/app/model \
vllm-rocm bash
# Run vllm api server
VLLM_USE_TRITON_FLASH_ATTN=1 CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server --max-model-len 3072 --download-dir /app/model --quantization=gptq --tensor-parallel-size=1 --enforce-eager --trust-remote-code --dtype=auto --kv-cache-dtype=auto --quantization-param-path=None --device=cuda --block-size=16 --model TechxGenus/Meta-Llama-3-70B-Instruct-GPTQ
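Once the server is up, a quick sanity check against the OpenAI-compatible endpoint looks something like this (assuming the default port 8000 and the model name from the command above):

# Sanity check against the OpenAI-compatible API started by the server command above.
# Assumes the server is listening on the default port 8000.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "TechxGenus/Meta-Llama-3-70B-Instruct-GPTQ",
        "prompt": "San Francisco is a",
        "max_tokens": 64,
        "temperature": 0,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])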
Asking that the engineers at AMD look into this and assist in troubleshooting/getting this working for Radeon GPUs (Navi3).
if possible, can you try building triton from source?
That won't work, I think. There's a related Flash Attention discussion for gfx1100 here: https://github.com/ROCm/aotriton/issues/16, although according to that thread, Navi support was upstreamed last month and the appropriate place to file any navi31 Triton issues is the main repo: https://github.com/openai/triton
(The vLLM bug at the moment is just that it's not checking for gfx1100 correctly; it shouldn't be trying to use the Triton FA at all?)
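For what it's worth, a rough sketch of the kind of up-front check I'd expect the backend to make; the capability == 11 test is what vLLM itself uses for the naive path, while gcnArchName is only an assumption about what your ROCm PyTorch build exposes:

# Sketch: detect RDNA3 / gfx11xx so the Triton FA path can be skipped up front.
import torch

def is_navi3() -> bool:
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    if major == 11:  # gfx11xx reports compute capability 11.x on ROCm
        return True
    # Fallback: newer ROCm PyTorch builds expose the arch name directly
    # (assumption: gcnArchName is present on your build).
    props = torch.cuda.get_device_properties(0)
    return getattr(props, "gcnArchName", "").startswith("gfx11")

print("Navi3 detected:", is_navi3())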
The howiejay branch should build fine on the latest torch stable running ROCm 6. I have py3.11 and py3.12 wheels built against gfx1100 and ROCm 6.0 here. All I run is
pip wheel git+https://github.com/ROCm/flash-attention@howiejay/navi_support --no-deps
in my virtualenvs to produce the wheels.
So I built vLLM with defaults then set VLLM_USE_TRITON_FLASH_ATTN=0 at runtime. On unquantized Llama3 8B I peaked at something like 1550 T/S with BS=96 and 0.95 memory allocation on a 7900 XTX 24G. 400 token response with a few hundred in context. Seems okay-ish?
Update: There's an internal gate against using the CK FA for Navi even if it's installed, because there's no varlen_fwd() support. You can build and install the howiejay flash-attn fine, but it seems to only be useful for diffusion models atm.
Additionally, I built ROCm/triton from source as of an hour ago and it still just sits pegging one thread for a small eternity before eventually being killed for blowing up the stack. I guess a person could try to increase the stack size, but I really feel like something's not working...
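A quick probe of what a given flash-attn build actually exposes (this only inspects the Python package surface, assuming the upstream flash_attn layout; it is not vLLM's internal gate):

# Probe: does the installed flash-attn expose the varlen entry points?
# Only inspects the public package surface (assumption: upstream flash_attn layout).
try:
    import flash_attn
except ImportError:
    print("flash-attn is not installed")
else:
    print("flash-attn version:", getattr(flash_attn, "__version__", "unknown"))
    for name in ("flash_attn_func", "flash_attn_varlen_func"):
        print(f"{name}: {'present' if hasattr(flash_attn, name) else 'missing'}")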
I think I narrowed it to this autotune:
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4)
Disabling that, I can run without VLLM_USE_TRITON_FLASH_ATTN=0. I'm using Triton nightly as of an hour ago to make sure it has any possible Navi fixes. Though if anything it feels slower? I'll try stable Triton in a bit.
Patch on top of v0.4.2 if someone else wants to play with it.
diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 11476641..d5f6bbec 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -219,16 +219,16 @@ def _attn_fwd_inner(
num_stages=1,
num_warps=8,
),
- triton.Config(
- {
- "BLOCK_M": 128,
- "BLOCK_N": 128,
- "waves_per_eu": 2,
- "PRE_LOAD_V": False,
- },
- num_stages=1,
- num_warps=4,
- ),
+ # triton.Config(
+ # {
+ # "BLOCK_M": 128,
+ # "BLOCK_N": 128,
+ # "waves_per_eu": 2,
+ # "PRE_LOAD_V": False,
+ # },
+ # num_stages=1,
+ # num_warps=4,
+ # ),
triton.Config(
{
"BLOCK_M": 256,
The upstream Triton supports Navi3, but attention performance is slow.
Alright I tried with stable triton and the ROCm triton fork. My patch only helped the official nightly run without hanging.
pip uninstall pytorch-triton-rocm -y; pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps
There might be more configs that need to be disabled to run stable triton? A person could maybe just disable every config with a block dim ≥ 128 and it'd probably work everywhere. I think navi favors the small ones anyways?
I also found triton is indeed much faster than naive once you stack the context.
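After swapping Triton packages like that, I double-check which build actually ends up active (as far as I can tell, both the nightly and pytorch-triton-rocm install under the triton module name):

# Check which Triton build is actually being imported after swapping packages.
import triton
print("triton version:", triton.__version__)
print("installed at:", triton.__file__)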
Is there a solution available? I also encountered this issue with AMD w6800
I'd wager it's related to https://github.com/ROCm/triton/issues/596
Commenting out every autotune with a block size ≥128 allows it to compile using pytorch-triton-rocm==2.3.1 for me on gfx1100.
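If you'd rather prune programmatically than comment configs out by hand, the idea looks roughly like this (autotune_configs is a stand-in for the list passed to @triton.autotune in vllm/attention/ops/triton_flash_attention.py, and the first config is only illustrative):

# Sketch: keep only autotune configs whose block dims stay below 128,
# mirroring the manual commenting-out in the patch above.
import triton

autotune_configs = [
    # illustrative "small" config
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 4, "PRE_LOAD_V": False},
                  num_stages=1, num_warps=8),
    # the one that blows the stack on gfx1100
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False},
                  num_stages=1, num_warps=4),
]

navi_safe_configs = [
    cfg for cfg in autotune_configs
    if max(cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"]) < 128
]
print(f"kept {len(navi_safe_configs)} of {len(autotune_configs)} configs")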
@Beinsezii Thank you very much for your help. After following your method, the previous error (error: triton_flash_attention.py:211:0: stack frame size) no longer appears. However, answer generation runs very slowly, equivalent to the speed with VLLM_USE_TRITON_FLASH_ATTN=0. Is it a problem with my graphics card? AMD W6800
The W6800 has no WMMA accelerators so I'm not sure it'll be faster. It should still use less memory for long context models though.
With https://github.com/ROCm/triton/issues/596 closed, I decided to rebuild triton-lang/triton and was able to run VLLM_USE_TRITON_FLASH_ATTN=1 on an unmodified vLLM 0.4.2 + gfx1100.
The problem is that rocBLAS is not supported on Navi architectures by ROCm. Hence FA won't work in general, I think.
Your current environment
🐛 Describe the bug
I'm able to build the ROCm docker image for AMD via the latest docs: https://docs.vllm.ai/en/latest/getting_started/amd-installation.html#option-1-build-from-source-with-docker-recommended
I am using a W7900 (RDNA3; navi31; gfx1100) and therefore use BUILD_FA="0", i.e. without Flash Attention. When I run any script (say benchmarks/benchmark_latency.py), I get an error: it's trying to use Triton, which seems to use an implementation of flash attention?
Stepping through the code, it goes through selector.py, and that sends it to the ROCm backend. In the backend, there is a switch for navi3x: https://github.com/vllm-project/vllm/blob/d6f4bd7cddc9546c38568c92c3772d22940a09f2/vllm/attention/backends/rocm_flash_attn.py#L167
It sets self.use_naive_attn = True if torch.cuda.get_device_capability()[0] == 11 (gfx11xx). So far so good, but that branch only executes if self.use_triton_flash_attn is false, which is set by VLLM_USE_TRITON_FLASH_ATTN and defaults to True. So, in order to get this running, you need to have VLLM_USE_TRITON_FLASH_ATTN=0 in your env. This isn't in the docs, nor is it set by default when you BUILD_FA="0".
Presumably, the correct way to fix this is for the ROCm implementation to do correct navi3x checking and select the appropriate library/path based on which kernel is currently supported?
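To make the dispatch concrete, here is a paraphrased sketch of the gating as I read it, plus the kind of reordering that would make gfx11xx work out of the box; this is a simplification for discussion, not the actual rocm_flash_attn.py source:

# Paraphrased sketch of the backend gating described above -- not the real
# vLLM source, just the shape of the logic for discussion.
import os
import torch

def current_gating() -> str:
    use_triton = os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "1") != "0"  # defaults on
    is_navi3 = torch.cuda.get_device_capability()[0] == 11                 # gfx11xx
    if use_triton:
        return "triton"          # taken on Navi3 today -> stack-frame blow-up
    if is_navi3:
        return "naive"           # only reachable once the env var is set to 0
    return "ck_flash_attention"

def proposed_gating() -> str:
    # Check the hardware first, then honor the env var, so Navi3 never lands
    # on an unsupported kernel by default.
    use_triton = os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "1") != "0"
    is_navi3 = torch.cuda.get_device_capability()[0] == 11
    if is_navi3:
        return "naive"           # until a working Triton/CK kernel exists for gfx11xx
    return "triton" if use_triton else "ck_flash_attention"

print("current:", current_gating(), "| proposed:", proposed_gating())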