vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Performance]: Automatic Prefix Caching in multi-turn conversations #4917

Closed (hmellor closed this issue 6 months ago)

hmellor commented 6 months ago

I'm interested in the automatic prefix caching feature for multi-turn conversations but I can't seem to observe a performance improvement when prefix caching is enabled. This tweet from @vllm_project indicates that automatic prefix caching should benefit this use case.

I am using the following commands to start the vLLM server, first without and then with prefix caching enabled:

python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --port 7001 --gpu-memory-utilization 0.5 --disable-log-requests --enforce-eager

python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --port 7001 --gpu-memory-utilization 0.5 --disable-log-requests --enforce-eager --enable-prefix-caching

And the following script to simulate a multi-turn conversation from a user:

import time
from openai import OpenAI

user_messages = [
    "Tell me your ten favourite films",
    "Who directed each of these films?",
    "Which director has the most experience?",
    "What other films has this director directed?",
    "Do these films have anything in common?",
    "Which of those films is the oldest?",
    "How old was the director when this was released?",
]

client = OpenAI(api_key="api_key", base_url="http://localhost:7001/v1")

messages = []

start = time.perf_counter()
for user_message in user_messages:
    messages.append(dict(role="user", content=user_message))
    output = client.chat.completions.create(
        messages=messages,
        model="meta-llama/Llama-2-7b-chat-hf",
        temperature=0.0,
    )
    print(output.usage)
    assistant_message = output.choices[0].message
    messages.append(dict(role=assistant_message.role, content=assistant_message.content))
stop = time.perf_counter()
print(f"{stop - start = }")

With automatic prefix caching disabled I see:

$ python test.py 
CompletionUsage(completion_tokens=598, prompt_tokens=16, total_tokens=614)
CompletionUsage(completion_tokens=238, prompt_tokens=629, total_tokens=867)
CompletionUsage(completion_tokens=321, prompt_tokens=882, total_tokens=1203)
CompletionUsage(completion_tokens=465, prompt_tokens=1219, total_tokens=1684)
CompletionUsage(completion_tokens=446, prompt_tokens=1700, total_tokens=2146)
CompletionUsage(completion_tokens=212, prompt_tokens=2162, total_tokens=2374)
CompletionUsage(completion_tokens=58, prompt_tokens=2393, total_tokens=2451)
stop - start = 35.292656753677875

And with automatic prefix caching enabled I see:

$ python test.py 
CompletionUsage(completion_tokens=598, prompt_tokens=16, total_tokens=614)
CompletionUsage(completion_tokens=238, prompt_tokens=629, total_tokens=867)
CompletionUsage(completion_tokens=321, prompt_tokens=882, total_tokens=1203)
CompletionUsage(completion_tokens=468, prompt_tokens=1219, total_tokens=1687)
CompletionUsage(completion_tokens=459, prompt_tokens=1703, total_tokens=2162)
CompletionUsage(completion_tokens=197, prompt_tokens=2178, total_tokens=2375)
CompletionUsage(completion_tokens=60, prompt_tokens=2394, total_tokens=2454)
stop - start = 35.605276009999216

Is this expected?

hmellor commented 6 months ago

CC @robertgshaw2-neuralmagic (the tweet said the feature was added by Neural Magic, so you might have some insight into this feature)

robertgshaw2-neuralmagic commented 6 months ago

Will take a look at this case

comaniac commented 6 months ago

I'm also interested in this issue, so I benchmarked it today using the latest main branch, which already uses the flash-attn kernel for prefix caching. But even though I've verified that the prefix cache is being hit, I found no speedup when running the above script. I'll also investigate a bit.

robertgshaw2-neuralmagic commented 6 months ago

cc @SageMoore fyi

robertgshaw2-neuralmagic commented 6 months ago

I am not sure what GPU this is, but on an A100 we can do ~15000 prefill tokens/sec at fp16, so even a 2000-token prefill should only take ~0.13 seconds to process. Since APC only skips prefill computation, there is only about 0.5s of total time that it can save in this conversation. As a result, I would not really expect to see a speedup in this case (and in fact there is some overhead associated with managing another layer of indirection).

APC is really useful for cases with long shared prefixes (expensive prefills) and short decodes.
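For concreteness, here is a rough back-of-the-envelope version of that argument as a small script. The ~15000 tokens/sec figure is the assumed A100 prefill throughput quoted above, and the per-turn token counts are taken from the no-APC usage lines in the original report.

# Upper bound on what APC could save in the conversation above: even if every
# prefill were fully cached, only the prefill time disappears, not the decode time.
PREFILL_TOKENS_PER_SEC = 15_000  # assumed A100 fp16 prefill throughput

# prompt_tokens per turn from the no-APC usage lines above
prompt_tokens_per_turn = [16, 629, 882, 1219, 1700, 2162, 2393]

total_prefill_s = sum(prompt_tokens_per_turn) / PREFILL_TOKENS_PER_SEC
print(f"Total prefill time: ~{total_prefill_s:.2f}s")  # ~0.6s out of ~35s end-to-end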

comaniac commented 6 months ago

Thanks for the good hint. I changed the script to report the latency of every request rather than the total time (a sketch of this change follows the results), and here are the results on an L4 GPU:

w/o APC

stop - start = 28.363408592998894
CompletionUsage(completion_tokens=453, prompt_tokens=16, total_tokens=469)
stop - start = 13.253794727999775
CompletionUsage(completion_tokens=211, prompt_tokens=485, total_tokens=696)
stop - start = 15.47434264399999
CompletionUsage(completion_tokens=245, prompt_tokens=712, total_tokens=957)
stop - start = 22.77607062900279
CompletionUsage(completion_tokens=357, prompt_tokens=974, total_tokens=1331)
stop - start = 25.096272947001125
CompletionUsage(completion_tokens=392, prompt_tokens=1348, total_tokens=1740)
stop - start = 2.3558405980002135
CompletionUsage(completion_tokens=30, prompt_tokens=1757, total_tokens=1787)
stop - start = 2.8636473680016934
CompletionUsage(completion_tokens=37, prompt_tokens=1806, total_tokens=1843)

w. APC

stop - start = 28.40403065999999
CompletionUsage(completion_tokens=453, prompt_tokens=16, total_tokens=469)
stop - start = 13.463971014996787
CompletionUsage(completion_tokens=211, prompt_tokens=485, total_tokens=696)
stop - start = 15.43624263699894
CompletionUsage(completion_tokens=245, prompt_tokens=712, total_tokens=957)
stop - start = 22.343338724000205
CompletionUsage(completion_tokens=355, prompt_tokens=974, total_tokens=1329)
stop - start = 25.549687523998728
CompletionUsage(completion_tokens=403, prompt_tokens=1346, total_tokens=1749)
stop - start = 1.933658195001044
CompletionUsage(completion_tokens=30, prompt_tokens=1766, total_tokens=1796)
stop - start = 2.3811154130016803
CompletionUsage(completion_tokens=37, prompt_tokens=1815, total_tokens=1852)

This seems to align with your analysis.
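For reference, the per-request timing above can be obtained by moving the timer inside the loop of the original script. A minimal sketch follows; it assumes the client, model name, user_messages, and messages list are set up exactly as in the script earlier in the thread.

# Time each request individually instead of the whole conversation.
for user_message in user_messages:
    messages.append(dict(role="user", content=user_message))
    start = time.perf_counter()
    output = client.chat.completions.create(
        messages=messages,
        model="meta-llama/Llama-2-7b-chat-hf",
        temperature=0.0,
    )
    stop = time.perf_counter()
    print(f"{stop - start = }")
    print(output.usage)
    assistant_message = output.choices[0].message
    messages.append(dict(role=assistant_message.role, content=assistant_message.content))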

hmellor commented 6 months ago

I am not sure what GPU this is

My results came from an A100 40GB with --gpu-memory-utilization 0.5 and --enforce-eager (both of which would make my experiments slower).

there are only 0.5s worth of time that can be optimized in this case

Ok, so it's simply a case of my test not being suitable. If I were running a model with a more expensive prefill (i.e. bigger than 7B) and with longer prompts, I'd start to be able to observe the difference within a single conversation (albeit a subtle difference).

Presumably there is also a concurrency benefit, because the slot that would have been scheduled to execute the cached prefill can instead be used to process the prefill (or decoding) of a different request?

KuntaiDu commented 6 months ago

The key requirement for automatic prefix caching to deliver a sizable improvement is that the ratio between input token length and output token length should be VERY VERY large (ideally more than a 100x difference).

This is a very strong workload requirement, and this type of workload commonly occurs only in specific applications (e.g. asking questions about a very long software manual).
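As a hedged illustration of that kind of workload (not from the thread: manual.txt, the questions, and the token budget are all hypothetical, and the OpenAI client is assumed to be the one from the script above), every request below shares a long document prefix and produces a short answer, which is where APC can reuse most of the prefill work.

# Hypothetical APC-friendly workload: a long shared document prefix with many
# short questions, so nearly all prefill work can be served from the cache.
with open("manual.txt") as f:  # hypothetical long document (a couple thousand tokens)
    manual = f.read()

questions = [
    "How do I reset the device?",
    "What does error code 42 mean?",
    "Which firmware versions are supported?",
]

for question in questions:
    output = client.chat.completions.create(
        messages=[
            dict(role="system", content=f"Answer using this manual:\n{manual}"),
            dict(role="user", content=question),
        ],
        model="meta-llama/Llama-2-7b-chat-hf",
        max_tokens=64,  # short decodes
        temperature=0.0,
    )
    print(output.choices[0].message.content)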

hmellor commented 6 months ago

I ran a better test and have an interesting graph:

[image: per-turn latency plot for several first-prompt sizes]

Regardless of first prompt size, there seems to be a large fixed cost on turn 1 (i.e. the second turn), but not the subsequent turns.

@SageMoore:

robertgshaw2-neuralmagic commented 6 months ago

@hmellor - this is caused by Triton JIT compilation. The first time the server runs context_fwd_attention, Triton JIT-compiles the kernel, which slows us down. I have been meaning to finish off a PR that runs the JIT compilation during profiling, but it has become lower priority since, if you use latest main with flash attention, this issue is resolved because it uses the flash-attn kernels rather than Triton for context_fwd_attention.

robertgshaw2-neuralmagic commented 6 months ago

note: this will happen once per instantiation of the server
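A hedged aside for anyone benchmarking on such a version (not something suggested in the thread): since the JIT cost is paid once per server instance, sending the same short prompt twice before timing anything should produce a prefix-cache hit and absorb the one-time compilation. The prompt content and token counts below are arbitrary, and the OpenAI client is assumed to be the one from the script earlier in the thread.

# Hedged warm-up sketch: the second identical request should hit the prefix
# cache and pay the one-time Triton JIT cost before any measured requests.
warmup_prompt = "warm-up " * 32  # long enough that at least one KV block is cached

for _ in range(2):
    client.chat.completions.create(
        messages=[dict(role="user", content=warmup_prompt)],
        model="meta-llama/Llama-2-7b-chat-hf",
        max_tokens=1,
        temperature=0.0,
    )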

hmellor commented 6 months ago

if you use latest main with flash attention this issue is resolved

Is that the flash attention from the "pip install vllm-flash-attn for better performance." info log I've seen?

comaniac commented 6 months ago

Yes. You can just pip install vllm-flash-attn and make sure you see the log Using FlashAttention-2 when launching the server.

robertgshaw2-neuralmagic commented 6 months ago

I think it's now installed automatically:

https://github.com/vllm-project/vllm/blob/main/setup.py#L356

hmellor commented 6 months ago

Ok, thanks for clearing that up for me!

whr819987540 commented 1 day ago

How should we enable APC in the offline chat?
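Not an answer from the maintainers, just a hedged sketch assuming a vLLM version whose LLM class accepts enable_prefix_caching and provides LLM.chat: the same flag used on the server can be passed when constructing the offline engine.

# Hedged sketch: enable automatic prefix caching for offline chat, assuming
# LLM accepts enable_prefix_caching and exposes a chat() entry point.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    enable_prefix_caching=True,  # offline equivalent of --enable-prefix-caching
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=256)

messages = [dict(role="user", content="Tell me your ten favourite films")]
outputs = llm.chat(messages, sampling_params)
print(outputs[0].outputs[0].text)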