vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Performance]: [Automatic Prefix Caching] When hitting KV-cached blocks, the first execution is slow, and subsequent ones are fast. #5339

Open soacker opened 3 weeks ago

soacker commented 3 weeks ago

Proposal to improve performance

No response

Report of performance regression

I load llama3-70b onto 4 GPUs.

I record the common_computed_block_nums and the corresponding context_lens_tensor, and I also record the model execution time. Interestingly, I find that (1) when a request first hits the common prefix cache, model execution is slow, while every subsequent request that hits the prefix cache is fast; and (2) requests that do not hit the prefix cache at all appear faster than that first hit.

I understand that with a KV cache hit, the prefill stage runs a context-attention forward pass over the cached context, while without the KV cache it runs a full-attention forward pass over the whole prompt.
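For context, the runs below come from a setup roughly like the following (a minimal sketch, not my exact script; the model path, the shared prefix text, and the timing around `llm.generate` are placeholders):

```python
import time
from vllm import LLM, SamplingParams

# Sketch: llama3-70b sharded over 4 GPUs with automatic prefix caching enabled.
llm = LLM(
    model="meta-llama/Meta-Llama-3-70B",  # placeholder model path
    tensor_parallel_size=4,
    enable_prefix_caching=True,
)

shared_prefix = "<~200-token document shared by every request> "  # placeholder
prompts = [shared_prefix + f"Question {i}: ..." for i in range(6)]
params = SamplingParams(temperature=0.0, max_tokens=1)

# First batch: nothing is cached yet, so the prefill runs full attention.
t0 = time.time()
llm.generate(prompts, params)
print("no prefix cache hit:", time.time() - t0)

# Second batch with the same prefix: the prefix blocks are now cached, so the
# prefill only attends over the cached context. The very first cache hit is
# the one that shows the unexpectedly slow execution; later hits are fast.
t0 = time.time()
llm.generate(prompts, params)
print("prefix cache hit:", time.time() - t0)
```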

No prefix cache hit:

INFO 06-04 19:27:53 block_manager_v1.py:253] Automatic prefix caching is enabled.
Processed prompts: 0%| | 0/6 [00:00<?, ?it/s]
########## common_computed_block_nums:
########## common_computed_block_nums:
########## common_computed_block_nums:
########## common_computed_block_nums:
########## common_computed_block_nums:
########## common_computed_block_nums:
############ context_lens_tensor: tensor([0, 0, 0, 0, 0, 0], device='cuda:0', dtype=torch.int32)

model_executable time: 0.039289

First prefix cache hit:

########## common_computed_block_nums: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
############ context_lens_tensor: tensor([208], device='cuda:0', dtype=torch.int32)

model_executable time: 2.093436

Second prefix cache hit:

########## common_computed_block_nums: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
########## common_computed_block_nums: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
########## common_computed_block_nums: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
########## common_computed_block_nums: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
########## common_computed_block_nums: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
########## common_computed_block_nums: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
############ context_lens_tensor: tensor([208, 208, 208, 208, 208, 208], device='cuda:0', dtype=torch.int32)

model_executable time: 0.042899

Can somebody explain the reason for this, and are there any possible improvements?

Misc discussion on performance

No response

Your current environment (if you think it is necessary)

The output of `python collect_env.py`
robertgshaw2-neuralmagic commented 3 weeks ago

Are you running with Xformers backend or FlashAttention backend?

soacker commented 3 weeks ago

> Are you running with Xformers backend or FlashAttention backend?

Yeah, I use the FlashAttention backend.
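For anyone reproducing this: as far as I know, the backend can be forced with the `VLLM_ATTENTION_BACKEND` environment variable (the startup log prints which backend was actually selected), e.g.:

```python
import os

# Assumption: these backend names match the installed vLLM version; check the
# "Using ... backend" line in the startup log to confirm the selection.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"  # or "XFORMERS"

from vllm import LLM

llm = LLM(
    model="meta-llama/Meta-Llama-3-70B",  # placeholder model path
    tensor_parallel_size=4,
    enable_prefix_caching=True,
)
```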

Amelia26345 commented 2 weeks ago

I encountered a similar problem. With greedy sampling, a prompt length of 2000, output_len=1, and input requests that share no common prefix at all (every request is different), the inference results of the first 7 requests are identical with and without prefix caching, and the inference speed is also similar (3.4 token/s). For the 8th and subsequent requests, the results are inconsistent between the two settings, and the inference speed becomes particularly fast (17.6 token/s). Please take a look at this bug, thank you. vLLM version: 0.5.0, 1×A100-40G, llama2-13b.
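Roughly, my comparison looks like the sketch below (a minimal sketch, not my exact script; the model path and the prompt set are placeholders, and the two settings should be run in separate processes to avoid allocating the model twice on one GPU):

```python
import time
from vllm import LLM, SamplingParams

def run(enable_prefix_caching: bool, prompts):
    """Run each ~2000-token prompt once with greedy sampling and output_len=1."""
    llm = LLM(
        model="meta-llama/Llama-2-13b-hf",  # placeholder; 1xA100-40G
        enable_prefix_caching=enable_prefix_caching,
    )
    params = SamplingParams(temperature=0.0, max_tokens=1)
    results = []
    for i, prompt in enumerate(prompts):
        t0 = time.time()
        out = llm.generate([prompt], params)
        results.append((i, out[0].outputs[0].text, time.time() - t0))
    return results

# prompts: a list of ~2000-token inputs with no shared prefix (placeholder).
# Comparing run(False, prompts) and run(True, prompts), the outputs and speeds
# match for the first 7 requests but diverge from the 8th request onward.
```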