vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[RFC]: Implement disaggregated prefilling via KV cache transfer #5557

Open KuntaiDu opened 2 months ago

KuntaiDu commented 2 months ago

Motivation.

There are more and more use cases where we need to transfer KV caches between vLLM instances, or store KV caches for future use. Some concrete use cases:

Proposed Change.

My current thought is to introduce two new abstractions: communicator and KV database. The workflow will be

vllm <--> communicator <--> KV database

where

This will be a huge framework, with a wide range of challenging (but fun!) questions inside, including but not limited to:

Feel free to post any thoughts on the design! Is it good? Is this abstraction able to achieve the optimal performance in your use cases?
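To make the proposal concrete, here is a minimal sketch of what the two abstractions could look like. The interface and method names (`KVCommunicator`, `KVDatabase`, `send_kv`, etc.) are hypothetical, not an existing vLLM API:

```python
from abc import ABC, abstractmethod

import torch


class KVCommunicator(ABC):
    """Hypothetical transport between a vLLM instance and a peer / KV store."""

    @abstractmethod
    def send_kv(self, key: str, kv_tensors: list[torch.Tensor]) -> None:
        """Push the KV tensors (e.g. one per layer) associated with `key`."""

    @abstractmethod
    def recv_kv(self, key: str) -> list[torch.Tensor]:
        """Fetch the KV tensors associated with `key` from the remote side."""


class KVDatabase(ABC):
    """Hypothetical storage backend sitting behind a communicator."""

    @abstractmethod
    def put(self, key: str, kv_tensors: list[torch.Tensor]) -> None:
        """Persist KV tensors for later reuse."""

    @abstractmethod
    def get(self, key: str) -> list[torch.Tensor] | None:
        """Return stored KV tensors, or None if `key` is unknown."""
```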

Feedback Period.

Several weeks

CC List.

@simon-mo @youkaichao @zhuohan123 @cadedaniel @ywang96 @WoosukKwon @LiuXiaoxuanPKU

Any Other Things.

No response

KuntaiDu commented 2 months ago

After discussing, it may be better for us to focus on disaggregated prefilling first; then it will be much easier to tell how we should make the high-level architecture changes.

For disaggregated prefilling, does the following workflow sound good or not?

For an upcoming request:

leiwen83 commented 2 months ago

Sounds very interesting!

For the second use case, I have a question:

> The user wants to query a fixed set of long documents (examples: software manuals, internal documents, etc.). In this case, GPU memory + CPU memory may not be enough to store the KV cache of all documents, and we may want to store the KV cache of these documents and move it to GPU on demand.

It seems to leverage the prefix caching mechanism, which requires the document to be at the top of the prompt, with only the query part differing at the bottom, right? That way it could handle the case where long document pieces are combined with many different queries, and the KV cache of the shared top part would be stored in CPU memory?
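For illustration, a minimal sketch of that pattern on a single instance, assuming vLLM's automatic prefix caching (the model name and file path are placeholders; this is not the cross-instance transfer proposed in this RFC):

```python
from vllm import LLM, SamplingParams

# The long document is the shared prefix; only the query at the bottom differs,
# so the KV cache computed for the document part can be reused across queries.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

document = open("software_manual.txt").read()
queries = ["How do I reset the device?", "What does error code 42 mean?"]
prompts = [f"{document}\n\nQuestion: {q}\nAnswer:" for q in queries]

outputs = llm.generate(prompts, SamplingParams(max_tokens=128))
```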

And it would be better to also take into consideration GPUs without NVLink, like the 4090...

For KV compression, I think KV cache quantization to 4/2 bits would make this whole subsystem more valuable.

richardliaw commented 2 months ago

Would it make sense to first get some simple design on abstractions for handling the KV cache, before designing the transport?

For example, having something like:

# Producer side: run prefill and persist the resulting KV state (hypothetical API).
input_state = engine.prefill(input)
save(input_state, file)

# Consumer side: load the saved state into another engine and continue decoding.
input_state = read(file)
engine = engine.insert_state(input_state)
engine.generate(...)

Would be a nice starting point.

Then later maybe it can be async/lazy so that we would pipeline the state automatically

cadedaniel commented 2 months ago

I gave a comment offline, pasting it here:

The concept makes sense in vLLM, but I am concerned we are starting with the infra first instead of the impactful feature or performance optimization. What usually happens is that, because the infra is built without a narrow use case in mind, it is very difficult to prioritize design choices and infra features. Can we flip this on its head and instead build one of the user-impacting features/performance improvements, and work backwards from that to the infra features necessary? My thought is that prefill disagg has really tight performance constraints for KV transfer. It would be a big waste if the eventual implementation couldn't use this work because the performance requirements weren't known ahead of time.

AnikinNN commented 2 months ago

I have found one more use case for storing the KV cache somewhere. I suppose it would be nice to have this feature when working with agents, such as chain-of-thought style agents. They have repeatable phases of generation followed by appending a tool's outputs. As of now, every time generation stops due to a tool invocation and the tool's outputs are appended to the prompt, the LLM is called again. We have a growing leading part of the prompt that is the same within one chain-of-thought run.

KuntaiDu commented 2 months ago

> Sounds very interesting!
>
> For the second use case, I have a question:
>
> > The user wants to query a fixed set of long documents (examples: software manuals, internal documents, etc.). In this case, GPU memory + CPU memory may not be enough to store the KV cache of all documents, and we may want to store the KV cache of these documents and move it to GPU on demand.
>
> It seems to leverage the prefix caching mechanism, which requires the document to be at the top of the prompt, with only the query part differing at the bottom, right? That way it could handle the case where long document pieces are combined with many different queries, and the KV cache of the shared top part would be stored in CPU memory?
>
> And it would be better to also take into consideration GPUs without NVLink, like the 4090...
>
> For KV compression, I think KV cache quantization to 4/2 bits would make this whole subsystem more valuable.

In the long-document reuse case, sure, CPU memory can be used as a layer of cache. But there are two scenarios where using CPU as a KV cache is NOT efficient:

For those devices without NVLink, I agree with you, it would be nice if we can support them. But let's focus on making the KV transfer REALLY fast using NVLink first (which is a cool feature that trt/tgi/lmdeploy does not have), so that we can gauge more interest from other developers.

For KV compression, there is a series of research that explores alternative opportunities besides simple quantization. Some pointers: https://arxiv.org/abs/2306.14048 (token filtering) and https://arxiv.org/pdf/2310.07240 (leveraging similarity between consecutive tokens for compression). So there are a lot of exciting opportunities besides simple quantization.

KuntaiDu commented 2 months ago

> Would it make sense to first get some simple design on abstractions for handling the KV cache, before designing the transport?
>
> For example, having something like:
>
> # Producer side: run prefill and persist the resulting KV state (hypothetical API).
> input_state = engine.prefill(input)
> save(input_state, file)
>
> # Consumer side: load the saved state into another engine and continue decoding.
> input_state = read(file)
> engine = engine.insert_state(input_state)
> engine.generate(...)
>
> Would be a nice starting point.
>
> Then later maybe it can be async/lazy so that we would pipeline the state automatically

Agree!!! A nuance here is what the granularity of KV cache reads/writes should be: per vLLM block or per query. My current preference is per vLLM block, as the moments when we need to read/save KV cache are typically tied to decisions of the block manager (e.g., we may need to read KV cache when the block manager allocates a new block, or we may need to write KV cache to disk when a block is swapped out of CPU memory by the block manager), so it is better to align the granularity with the block manager.
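A minimal sketch of what a per-block interface could look like (the class and method names are hypothetical; it assumes blocks are identified by a content hash, as vLLM's prefix caching already does):

```python
from abc import ABC, abstractmethod

import torch


class KVBlockBackend(ABC):
    """Hypothetical per-block KV storage, driven by block-manager events."""

    @abstractmethod
    def save_block(self, block_hash: int, kv_block: torch.Tensor) -> None:
        """Called e.g. when the block manager evicts or swaps out a block."""

    @abstractmethod
    def load_block(self, block_hash: int) -> torch.Tensor | None:
        """Called e.g. when the block manager allocates a block whose contents
        may already exist remotely; returns None on a miss."""
```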

Jeffwan commented 2 months ago

Great to see the proposal! We are doing experiments to offload reusable KV contents to an external cache store. Happy to discuss more details.

KuntaiDu commented 2 months ago

My current plan is to focus on implementing disaggregated prefilling using cross-vLLM-instance KV cache transfer. Two reasons:

KuntaiDu commented 2 months ago

Base implementation: 4 processes: prefilling instance, decoding instance

For a new incoming request:

Foreseeable overheads (compared to an implementation):

My very first step: measure the overhead of calling the prefilling function again with the KV cache.
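One rough way to get a first-order number for that (a sketch that approximates "prefill again with the KV cache already present" via automatic prefix caching on a single instance; the model name and prompt are placeholders):

```python
import time

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
params = SamplingParams(max_tokens=1)  # generating 1 token ~= prefill-only cost
prompt = "some long document text " * 200

def timed_prefill() -> float:
    start = time.perf_counter()
    llm.generate([prompt], params)
    return time.perf_counter() - start

cold = timed_prefill()  # full prefill, no cached KV
warm = timed_prefill()  # prefix-cache hit: the "prefill again with KV" path
print(f"cold prefill: {cold:.3f}s, prefill with cached KV: {warm:.3f}s")
```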

TopIdiot commented 1 month ago

Sounds great! And I think a scheduler is needed to decide which two instances a request should be scheduled to.
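For example, a minimal sketch of such a scheduler (the endpoint lists and the round-robin policy are illustrative assumptions, not part of this RFC):

```python
import itertools

# Hypothetical pools of vLLM HTTP endpoints.
PREFILL_INSTANCES = ["http://prefill-0:8000", "http://prefill-1:8000"]
DECODE_INSTANCES = ["http://decode-0:8000", "http://decode-1:8000"]

_prefill_rr = itertools.cycle(PREFILL_INSTANCES)
_decode_rr = itertools.cycle(DECODE_INSTANCES)

def schedule_request() -> tuple[str, str]:
    """Pick a (prefill, decode) instance pair for one incoming request."""
    return next(_prefill_rr), next(_decode_rr)
```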

leo6022 commented 1 month ago

How will the KV cache transfer be implemented: NCCL or RDMA?
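For reference, a KV tensor can be moved between two GPUs with NCCL point-to-point ops via torch.distributed; a minimal sketch (the tensor shape, 2-rank layout, and script name are illustrative, and RDMA would need a different transport such as UCX or a custom plugin):

```python
import torch
import torch.distributed as dist

# Launch with two processes, e.g.: torchrun --nproc_per_node=2 kv_transfer_demo.py
def main() -> None:
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Illustrative KV layout: [2 (K/V), num_blocks, block_size, num_heads, head_dim]
    kv = torch.empty(2, 16, 16, 32, 128, dtype=torch.float16, device="cuda")

    if rank == 0:                # "prefill" side: produce and send the KV cache
        kv.normal_()
        dist.send(kv, dst=1)
    else:                        # "decode" side: receive the KV cache
        dist.recv(kv, src=0)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```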