CC @robertgshaw2-neuralmagic (the tweet said the feature was added by Neural Magic, so you might have some insight into this feature)
Will take a look at this case
I'm also interested in this issue, so I benchmarked today using the latest main branch, which already uses the flash-attn kernel for prefix caching. But even though I've verified cache hits in the prefix cache, I found no speedup when running the above script. I'll also investigate a bit.
cc @SageMoore fyi
I am not sure what GPU this is, but on an A100 we can do ~15,000 prefill tokens/sec at fp16, so even a 2000-token prefill should only take ~0.13 seconds to process. Since APC only skips prefill computation, there are only 0.5s worth of time that can be optimized in this case. As a result, I would not really expect to see a speedup here (and in fact there is some overhead associated with managing another layer of indirection).
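To make that arithmetic explicit, here is a rough back-of-the-envelope sketch (the throughput figure is the ~15,000 tok/s quoted above; the per-turn prompt sizes are placeholders similar to the usage numbers reported later in this thread):

```python
# Back-of-the-envelope estimate of how much prefill time APC could possibly save.
PREFILL_TOKENS_PER_SEC = 15_000  # approximate A100 fp16 prefill throughput quoted above

# Placeholder per-turn prompt lengths, similar to the CompletionUsage numbers below.
prompt_tokens = [16, 485, 712, 974, 1348, 1757, 1806]

total_prefill_s = sum(prompt_tokens) / PREFILL_TOKENS_PER_SEC
print(f"Total prefill time without APC: ~{total_prefill_s:.2f}s")  # ~0.47s

# APC can only skip recomputing already-seen prefix tokens, so the best-case
# saving for the whole conversation is bounded by roughly this half second.
```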
APC really is useful for cases with long shared prefills and short decodes, e.g. asking many short questions against the same long document.
Thanks for the good hint. I changed the script to report the latency of every request instead of the total time; here are the results on an L4 GPU:
w/o APC
stop - start = 28.363408592998894
CompletionUsage(completion_tokens=453, prompt_tokens=16, total_tokens=469)
stop - start = 13.253794727999775
CompletionUsage(completion_tokens=211, prompt_tokens=485, total_tokens=696)
stop - start = 15.47434264399999
CompletionUsage(completion_tokens=245, prompt_tokens=712, total_tokens=957)
stop - start = 22.77607062900279
CompletionUsage(completion_tokens=357, prompt_tokens=974, total_tokens=1331)
stop - start = 25.096272947001125
CompletionUsage(completion_tokens=392, prompt_tokens=1348, total_tokens=1740)
stop - start = 2.3558405980002135
CompletionUsage(completion_tokens=30, prompt_tokens=1757, total_tokens=1787)
stop - start = 2.8636473680016934
CompletionUsage(completion_tokens=37, prompt_tokens=1806, total_tokens=1843)
w/ APC
stop - start = 28.40403065999999
CompletionUsage(completion_tokens=453, prompt_tokens=16, total_tokens=469)
stop - start = 13.463971014996787
CompletionUsage(completion_tokens=211, prompt_tokens=485, total_tokens=696)
stop - start = 15.43624263699894
CompletionUsage(completion_tokens=245, prompt_tokens=712, total_tokens=957)
stop - start = 22.343338724000205
CompletionUsage(completion_tokens=355, prompt_tokens=974, total_tokens=1329)
stop - start = 25.549687523998728
CompletionUsage(completion_tokens=403, prompt_tokens=1346, total_tokens=1749)
stop - start = 1.933658195001044
CompletionUsage(completion_tokens=30, prompt_tokens=1766, total_tokens=1796)
stop - start = 2.3811154130016803
CompletionUsage(completion_tokens=37, prompt_tokens=1815, total_tokens=1852)
That seems to align with your analysis.
I am not sure what GPU this is
My results came from an A100 40GB with --gpu-memory-utilization 0.5 and --enforce-eager (both of which would make my experiments slower).
there are only 0.5s worth of time that can be optimized in this case
Ok, so it's simply a case of my test not being suitable. If I were running a model with a more expensive prefill (i.e. bigger than 7B) and with longer prompts, I'd start to be able to observe the difference within a single conversation (albeit a subtle one).
Presumably there is also a concurrency benefit, because the slot that would have been scheduled to execute the cached prefill can instead be used to process the prefill (or decoding) of a different request?
The key thing for automatic prefix caching to deliver a sizable improvement is that the ratio between input token length and output token length should be VERY VERY large (ideally more than a 100x difference).
This is a very strong workload requirement, and this type of workload only commonly occurs in specific applications (e.g. asking questions about a very long software manual).
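To illustrate that kind of workload, here is a minimal offline sketch (model name, manual file, and questions are placeholders; `enable_prefix_caching=True` is the engine argument that turns APC on):

```python
from vllm import LLM, SamplingParams

# Placeholder model; APC is enabled via the engine argument below.
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", enable_prefix_caching=True)

manual = open("software_manual.txt").read()  # one very long shared prefix (placeholder)
questions = [
    "How do I reset the device?",
    "What does error code 42 mean?",
    "Where is the configuration file stored?",
]

# Short decodes: keeps the input/output token ratio very large.
params = SamplingParams(max_tokens=64)
prompts = [f"{manual}\n\nQuestion: {q}\nAnswer:" for q in questions]

# Every prompt shares the manual as a prefix, so after the first request the
# cached KV blocks for the manual are reused instead of being recomputed.
for out in llm.generate(prompts, params):
    print(out.outputs[0].text.strip())
```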
I ran a better test and have an interesting graph:
Regardless of the first prompt's size, there seems to be a large fixed cost on turn 1 (i.e. the second turn), but not on subsequent turns.
@SageMoore: do you know where in PrefixCachingBlockAllocator this comes from?
@hmellor - this is caused by Triton JITting. The first time the server runs context_fwd_attention, Triton JITs, which slows us down. I have been meaning to finish off a PR that runs the JITting during profiling, but it has become lower priority since if you use latest main with flash attention this issue is resolved, because it uses the flash-attn kernels rather than Triton for context_fwd_attn.
note: this will happen once per instantiation of the server
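Until that PR lands, one possible workaround (just a sketch, not something vLLM ships) is to absorb the one-time JIT cost with throwaway warm-up requests right after the server starts:

```python
import time
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
MODEL = "meta-llama/Llama-2-7b-chat-hf"  # placeholder

# Send the same long prompt twice: the second request should hit the prefix
# cache and exercise the context attention kernel, paying the Triton JIT cost
# here instead of on the first real user request.
warmup_prompt = "warm up " * 500
for i in range(2):
    start = time.perf_counter()
    client.completions.create(model=MODEL, prompt=warmup_prompt, max_tokens=1)
    print(f"warm-up request {i}: {time.perf_counter() - start:.2f}s")
```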
if you use latest main with flash attention this issue is resolved
Is that the flash attention from the "pip install vllm-flash-attn for better performance." info log I've seen?
Yes. You can just pip install vllm-flash-attn and make sure you see the Using FlashAttention-2 log when launching the server.
I think it's now installed automatically: https://github.com/vllm-project/vllm/blob/main/setup.py#L356
Ok, thanks for clearing that up for me!
How should we enable APC in the offline chat?
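For offline use, APC is enabled with the same engine argument as the server flag; a minimal sketch, assuming a recent vLLM where `LLM.chat` is available (model name and messages are placeholders):

```python
from vllm import LLM, SamplingParams

# enable_prefix_caching=True is the offline equivalent of --enable-prefix-caching.
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", enable_prefix_caching=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarise the attached manual."},
]
outputs = llm.chat(messages, SamplingParams(max_tokens=128))
print(outputs[0].outputs[0].text)
```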
I'm interested in the automatic prefix caching feature for multi-turn conversations but I can't seem to observe a performance improvement when prefix caching is enabled. This tweet from @vllm_project indicates that automatic prefix caching should benefit this use case.
I am using the following commands to start the vLLM server:
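A representative invocation (the model name is a placeholder; --enable-prefix-caching toggles APC, while --gpu-memory-utilization and --enforce-eager match the settings mentioned above):

```bash
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-2-7b-chat-hf \
    --gpu-memory-utilization 0.5 \
    --enforce-eager \
    --enable-prefix-caching   # omit this flag for the "APC disabled" runs
```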
And the following script to simulate a multi-turn conversation from a user:
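A minimal sketch of a script of this shape, assuming the openai Python client against the OpenAI-compatible endpoint (model name and user turns are placeholders); it prints per-request latency and token usage in the same stop - start / CompletionUsage format quoted above:

```python
import time
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
MODEL = "meta-llama/Llama-2-7b-chat-hf"  # placeholder

# Placeholder user turns; each turn is appended to the same growing
# conversation, so every request shares all previous tokens as a prefix.
user_turns = [
    "Tell me about the history of the Roman Empire.",
    "Which emperor ruled for the longest?",
    "What happened after his death?",
]

messages = []
for turn in user_turns:
    messages.append({"role": "user", "content": turn})
    start = time.perf_counter()
    response = client.chat.completions.create(model=MODEL, messages=messages)
    stop = time.perf_counter()
    print("stop - start =", stop - start)
    print(response.usage)  # CompletionUsage(completion_tokens=..., prompt_tokens=..., ...)
    messages.append({"role": "assistant", "content": response.choices[0].message.content})
```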
With automatic prefix caching disabled I see:
And with automatic prefix caching enabled I see:
Is this expected?