vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

Flash Attention V2 #485

Closed nivibilla closed 1 year ago

nivibilla commented 1 year ago

https://github.com/Dao-AILab/flash-attention

Flash Attention v2 was released claiming 2x speedups. Making an issue to remind myself to have a look at it, and also in case anyone else wants to try to implement it.

chenrui17 commented 1 year ago

I used benchmarks/benchmark_throughput.py to test FlashAttention V2, but it doesn't seem to have any effect. My test steps were:
  • update xformers to latest version
  • modify the line to self.attn_op = xops.fmha.flash.FwOp()
  • python3 benchmark_throughput.py --dataset=./ShareGPT_V3_unfiltered_cleaned_split.json --model=/huggingface_data/llama-7b-hf/ --tokenizer=hf-internal-testing/llama-tokenizer --num-prompts=500

Comparing the test times and analyzing the performance further, I found that the replaced part (FlashAttention V2) accounts for only a small share of the total cost and only runs at the beginning of the execution. I am confused: for FlashAttention V2, what can we do for vLLM?
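One way to see that split directly is to wrap a short run in torch.profiler and sort kernels by CUDA time. A minimal sketch, assuming vLLM's Python API; the model name, prompts, and generation length are placeholders:

```python
# Profile a short vLLM run to see how much time goes to prompt-phase (prefill)
# attention vs. everything else. On vLLM of this era the prompt attention shows
# up as an xformers/flash kernel, while decoding runs vLLM's own paged-attention
# kernel, so the flash kernel only appears at the start of each request.
import torch
from vllm import LLM, SamplingParams

llm = LLM(model="huggyllama/llama-7b")  # placeholder; any local Llama-7B path works
prompts = ["Summarize the following text: " + "lorem ipsum " * 200] * 8
params = SamplingParams(max_tokens=256)

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
) as prof:
    llm.generate(prompts, params)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```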

tmm1 commented 1 year ago
  • update xformers to latest version

Hi, which version specifically? I don't think flash v2 support has been released in xformers yet, so you would have to install it from git. There are also still open PRs to bump xformers to the flash-attn v2.0.4 bugfix release (https://github.com/facebookresearch/xformers/pull/816).

tmm1 commented 1 year ago
  • modify the line to self.attn_op = xops.fmha.flash.FwOp()
  • python3 benchmark_throughput.py --dataset=./ShareGPT_V3_unfiltered_cleaned_split.json --model=/huggingface_data/llama-7b-hf/ --tokenizer=hf-internal-testing/llama-tokenizer --num-prompts=500

I tried this as well, and there was no improvement in the benchmarks after switching to flash-attn v2.

I will try to profile the benchmark script.
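To check whether the kernel swap matters in isolation (outside of vLLM's end-to-end pipeline), a small micro-benchmark of the two xformers forward ops can help. A sketch, assuming a CUDA GPU and an xformers build that ships both ops; the shapes, dtype, and iteration counts are illustrative:

```python
# Compare xformers' CUTLASS-based and FlashAttention-based forward ops on a
# prompt-shaped, causally-masked input. Requires CUDA, fp16/bf16 inputs, and
# an xformers build that includes the flash op; all sizes are illustrative.
import torch
import xformers.ops as xops

batch, seqlen, heads, head_dim = 8, 1024, 32, 128
q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
causal = xops.LowerTriangularMask()

def time_op(op, iters=50):
    # Warm up, then time forward-only attention with the given op (ms/call).
    for _ in range(5):
        xops.memory_efficient_attention_forward(q, k, v, attn_bias=causal, op=op)
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        xops.memory_efficient_attention_forward(q, k, v, attn_bias=causal, op=op)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print(f"cutlass FwOp: {time_op(xops.fmha.cutlass.FwOp):.3f} ms")
print(f"flash FwOp:   {time_op(xops.fmha.flash.FwOp):.3f} ms")
```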

Zhuqln commented 1 year ago

modify the line to self.attn_op = xops.fmha.flash.FwOp()

I don't think this change alone really works. Another important feature of flash-attn is that it cuts the very high GPU memory usage of attention for super long contexts (more than 5k tokens). When I set that line and ran inference, I didn't see any change in memory usage.
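To check the memory side of this specifically, one can compare the peak allocation of a naive attention forward against the two xformers ops on a long prompt. A sketch with illustrative shapes; note that both xformers ops already avoid materializing the full attention matrix, so little difference between them is expected, which matches the observation above:

```python
# Peak GPU memory of naive attention vs. the two xformers forward ops on a
# long-ish prompt. The naive version materializes a (seqlen x seqlen) score
# matrix per head; both xformers ops avoid that. Raise seqlen toward 8k+ to
# mimic very long contexts if your GPU has the headroom.
import torch
import xformers.ops as xops

batch, seqlen, heads, head_dim = 1, 4096, 32, 128
q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

def peak_mem_mib(fn):
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

def naive():
    qh, kh, vh = (t.permute(0, 2, 1, 3) for t in (q, k, v))  # [B, H, S, D]
    scores = qh @ kh.transpose(-1, -2) / head_dim ** 0.5      # [B, H, S, S]
    return torch.softmax(scores, dim=-1) @ vh

print(f"naive:   {peak_mem_mib(naive):.0f} MiB")
print(f"cutlass: {peak_mem_mib(lambda: xops.memory_efficient_attention_forward(q, k, v, op=xops.fmha.cutlass.FwOp)):.0f} MiB")
print(f"flash:   {peak_mem_mib(lambda: xops.memory_efficient_attention_forward(q, k, v, op=xops.fmha.flash.FwOp)):.0f} MiB")
```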

WoosukKwon commented 1 year ago

Hi @nivibilla, thanks for submitting the issue. The latest version of xformers now uses the FlashAttention-V2 algorithm, so vLLM also now takes advantage of it. Please upgrade vLLM to v0.1.4.

@tmm1 @Zhuqln To my understanding, the overall speedup should depend on your workload. At inference time, FlashAttention is only used for the prompt inputs, and never used for the decoding inputs. For many workloads, the decoding stage takes the majority of the total execution time, so changing to FlashAttention V2 may not give a notable speedup. However, for other workloads like text summarization, where the prompts are very long, I believe computing attention for the prompt inputs will take a significant portion of the execution time, and thus FlashAttention V2 will have a huge impact on the overall performance.
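A rough back-of-envelope over attention work makes that workload dependence concrete. A minimal sketch with illustrative prompt/generation lengths; it counts only attention score/value math and ignores the MLP, KV-cache traffic, and kernel efficiency:

```python
# Share of attention work spent on the prompt (prefill) vs. on generated tokens
# (decode). Prefill attends prompt_len tokens to prompt_len tokens; each decode
# step attends one new token to everything seen so far. Purely illustrative.
def prefill_share(prompt_len: int, gen_len: int) -> float:
    prefill = prompt_len * prompt_len
    decode = sum(prompt_len + i for i in range(1, gen_len + 1))
    return prefill / (prefill + decode)

for prompt_len, gen_len in [(128, 512), (2048, 128), (8192, 512)]:
    print(f"prompt={prompt_len:5d} gen={gen_len:4d} -> "
          f"prefill is ~{prefill_share(prompt_len, gen_len):.0%} of attention work")
```

Short prompts with long generations leave little for FlashAttention to speed up, while long prompts with short generations are dominated by prefill; decode steps are also typically memory-bound, which this crude count does not capture.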

nivibilla commented 1 year ago

@WoosukKwon thanks for the explanation!

tmm1 commented 1 year ago

The latest version of xformers now uses the FlashAttention-V2 algorithm, so vLLM also now takes advantage of it. Please upgrade vLLM to v0.1.4.

Hi, this is inaccurate, since the code is still forcing xops.fmha.cutlass.FwOp to be used. If you want to take advantage of FA2, you need to switch to xops.fmha.flash.FwOp.

See benchmark results in https://github.com/facebookresearch/xformers/issues/832
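For reference, the kind of one-line change being discussed looks roughly like this. This is only a sketch in the style of vLLM's xformers-backed attention of that era; the class below is abbreviated and hypothetical, and the exact file and surrounding code in an installed version may differ:

```python
# Hypothetical, abbreviated version of the attention wrapper discussed above,
# showing where the forward-op choice lives. In vLLM of that era the prompt
# (prefill) attention went through xformers, while decoding used vLLM's own
# paged-attention kernel and is unaffected by this swap.
import xformers.ops as xops

class PromptAttentionSketch:
    def __init__(self, scale: float):
        self.scale = scale
        # Default in the code being discussed:
        # self.attn_op = xops.fmha.cutlass.FwOp()
        # Swap to the FlashAttention op (FA2 in recent flash-attn/xformers builds):
        self.attn_op = xops.fmha.flash.FwOp()

    def forward_prompt(self, query, key, value, attn_bias):
        # query/key/value: [batch, seqlen, heads, head_dim]
        return xops.memory_efficient_attention_forward(
            query, key, value, attn_bias=attn_bias,
            scale=self.scale, op=self.attn_op,
        )
```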

zhaoyang-star commented 1 year ago

The latest version of xformers now uses the FlashAttention-V2 algorithm, so vLLM also now takes advantage of it. [...] At inference time, FlashAttention is only used for the prompt inputs, and never used for the decoding inputs.

Thanks for the details @WoosukKwon. I just have a question: why can't FlashAttention be used for the decoding phase?

learning-chip commented 1 year ago

Why FlashAttention could not be used for decoding phase?

Its tiling strategy is not optimized for Q with seqlen=1 https://github.com/Dao-AILab/flash-attention/issues/427#issuecomment-1668257877
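To make that concrete: during decoding the query is a single new token, so attention reduces to a softmax-weighted read over the cached keys/values, and there is no query-length dimension for FlashAttention's tiling to exploit. A minimal reference sketch in plain PyTorch (not vLLM's paged-attention kernel):

```python
# Decode-time attention for one new token against the whole KV cache.
# With seqlen_q == 1 the work is essentially a batched vector-matrix product
# plus a softmax, which is what the linked flash-attention issue is about.
import torch

def single_query_attention(q, k_cache, v_cache, scale):
    # q:       [heads, head_dim]             (the token being generated)
    # k_cache: [context_len, heads, head_dim]
    # v_cache: [context_len, heads, head_dim]
    scores = torch.einsum("hd,chd->hc", q, k_cache) * scale  # [heads, context_len]
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hc,chd->hd", probs, v_cache)        # [heads, head_dim]

heads, head_dim, context_len = 32, 128, 2048
q = torch.randn(heads, head_dim)
k = torch.randn(context_len, heads, head_dim)
v = torch.randn(context_len, heads, head_dim)
print(single_query_attention(q, k, v, scale=head_dim ** -0.5).shape)  # torch.Size([32, 128])
```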

Lvjinhong commented 9 months ago

At inference time, FlashAttention is only used for the prompt inputs, and never used for the decoding inputs. [...]

Why can't FlashAttention be used for the decoding phase?

I'm delighted to engage in this discussion. Your report has been immensely helpful, but I do have some questions. For instance, I'm curious to know if there's a performance comparison available between trtLLM and vLLM. Such information would be greatly beneficial in guiding my decision on which framework to choose.

matanhol commented 7 months ago

However, for other workloads like text summarization, where the prompts are very long, I believe computing attention for the prompt inputs will take a significant portion of the execution time, and thus FlashAttention V2 will have a huge impact on the overall performance.

You assume that in a summarization task most of the workload is in processing the prompt (the prefill). In my experimentation I saw that the generation side is much larger. If you generate only 1-5 tokens, then most of the workload is indeed prompt processing, the cost depends on the input length, and FlashAttention 2 is advantageous (it handles long inputs much better than a naive implementation, whose attention matrix grows quadratically with input length). But if you generate a considerable number of tokens, that factor dominates, the prompt processing becomes negligible, and FlashAttention 2 has little leverage there. (Usually when you have a long text you want a longer summary; it doesn't make sense to summarize a 1000-word article in 5 tokens.) A link to my simulation is attached; please let me know if you have any comments.

https://github.com/matanhol/summarization_with_flash_attn_2_simulation

brando90 commented 4 weeks ago

I tried installing vLLM with flash-attn, but it didn't work. My attempts:

Install flash attention:
```bash
# my current vllm setup without flash
# pip install --upgrade pip
# pip install torch==2.2.1
# pip install vllm==0.4.1

# flash attn https://amzn-aws.slack.com/archives/C06Q26TNN8G/p1724182667464149
# flash-attn>=2.5.8
# pip install flash-attn
# Collabs's setup with flash
# vllm                              0.5.4
# vllm-flash-attn                   2.6.1
# flash-attn                        2.6.3
# torch                             2.4.0
# Python 3.10.8 

# try to install flash attn in a new py env
python3.11 -m venv ~/.virtualenvs/flash_attn_test
source ~/.virtualenvs/flash_attn_test/bin/activate
pip install --upgrade pip
pip install -e ~/snap-cluster-setup

pip list | grep vllm
pip list | grep torch
pip list | grep flash-attn
pip list | grep vllm-flash-attn

# # didn't work
# pip install torch==2.2.1
# pip install vllm==0.4.1
# MAX_JOBS=4 pip install flash-attn --no-build-isolation --force

# this installed flash-attn, but vllm didn't say in its output that it was using it
pip install torch==2.4.0
pip install vllm==0.5.4
pip install flash-attn==2.6.3
pip install vllm-flash-attn==2.6.1

python ~/snap-cluster-setup/py_src/evals/boxed_acc_eval.py --model internlm/internlm2_5-1_8b --hf_gen_type vllm --path_2_eval_dataset ~/snap-cluster-setup/data/MATH/test --max_tokens 2048 --batch_size 100 --end 100 -n 1 --shuffle True --mode dryrun 2>&1 | tee $LOG_FILE && echo "Log file created at: $LOG_FILE"

# later try with py 3.10
# python3xxx -m venv ~/.virtualenvs/flash_attn_test_py10
# source ~/.virtualenvs/flash_attn_test_py10/bin/activate
# pip install --upgrade pip
# pip install -e ~/snap-cluster-setup
# pip install torch==2.4.0
# pip install vllm==0.5.4
# pip install flash-attn==2.6.3
# pip install vllm-flash-attn==2.6.1
```

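Whether vLLM actually picked up FlashAttention is easier to tell from its backend selection at startup than from pip list. A hedged sketch: the environment variable and log wording below are based on my reading of vLLM 0.4.x/0.5.x and may differ in other versions, and the model is just the one used above:

```python
# Force vLLM's flash-attention backend and do a tiny dry run. If flash-attn /
# vllm-flash-attn is not usable in this environment, backend selection should
# fall back or error out, which is a more direct signal than `pip list`.
# VLLM_ATTENTION_BACKEND and the log wording below are assumptions based on
# vLLM 0.4.x/0.5.x; check your version's docs if they do not match.
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"  # set before importing vllm

from vllm import LLM, SamplingParams

llm = LLM(model="internlm/internlm2_5-1_8b", trust_remote_code=True, max_model_len=2048)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
# Watch the startup logs for a line like "Using FlashAttention-2 backend." or
# "Using XFormers backend." to see which attention backend was actually chosen.
```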
brando90 commented 4 weeks ago

My setting is Python 3.11; that is what I really want/need.

brando90 commented 4 weeks ago

Related general vLLM issue about versions: https://github.com/vllm-project/vllm/issues/2747