vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: Tree attention about Speculative Decoding #3960

Open yukavio opened 2 months ago

yukavio commented 2 months ago

🚀 The feature, motivation and pitch

I want to implement tree attention for vLLM, as mentioned in the roadmap. But I don't know whether I should implement it on top of the paged-attention kernel implemented in vLLM or on top of FlashInfer, since I found that we plan to replace this kernel in this PR.

Alternatives

No response

Additional context

No response

cadedaniel commented 2 months ago

Thanks for your interest in contributing! FYI, tree attention is a bit complicated to implement with a non-contiguous KV cache, since intra-block attention masking has not been implemented anywhere AFAIK. We can get around this by limiting vLLM to a block size of 1, but that makes it difficult to optimize verification latency because it restricts the allowed vLLM configuration space.
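
For reference, the masking tree attention needs is: each speculative token may attend to the prompt/accepted tokens plus its own ancestors in the candidate tree. A minimal sketch of building such a mask from parent pointers (plain PyTorch, not vLLM code; the function name and tree encoding are just illustrative):

```python
import torch

def tree_attention_mask(parents: list[int], num_prefix_tokens: int) -> torch.Tensor:
    """Boolean attention mask for a packed candidate-token tree.

    parents[i] is the in-tree index of token i's parent, or -1 if its parent
    is the last accepted (prefix) token. Each candidate may attend to every
    prefix token, to itself, and to its ancestors in the tree.
    """
    n = len(parents)
    mask = torch.zeros(n, num_prefix_tokens + n, dtype=torch.bool)
    mask[:, :num_prefix_tokens] = True          # every candidate sees the prefix
    for i, p in enumerate(parents):
        mask[i, num_prefix_tokens + i] = True   # attend to self
        while p != -1:                          # walk up to the root
            mask[i, num_prefix_tokens + p] = True
            p = parents[p]
    return mask

# Tree over a 4-token prefix:  t0 -> t1,  t0 -> t2,  t2 -> t3
print(tree_attention_mask(parents=[-1, 0, 0, 2], num_prefix_tokens=4).int())
```

With a block size greater than 1, several of these candidate tokens can land in the same KV-cache block, so a mask like this has to be applied inside a block as well as across blocks; that is the intra-block masking mentioned above.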

The way I'd recommend going about this is to implement intra-block attention masking first, then integrate it with vLLM. This is the surefire way to obtain the best latency reduction possible in vLLM. The steps are as follows:

After the remaining open sourcing work is complete, I'll add some documentation for this.

More background information here: https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit#heading=h.kk7dq05lc6q8

reyoung commented 2 months ago

@cadedaniel @yukavio

Tree attention mechanisms can also be utilized to generate multiple outcomes from the same prompt by varying the seeds.

This is an effective way to stabilize the results produced by Large Language Models (LLMs). For instance, when using an LLM as a scoring tool to derive metrics, one can sample its output multiple times and average the samples to obtain a more reliable score.


This feature might become available following the implementation of tree attention mechanisms.
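
Independent of how tree attention gets implemented, the sample-and-average idea can already be sketched with the existing offline API; a rough illustration, assuming the offline `LLM`/`SamplingParams` interface (the model name and the score-parsing step are placeholders):

```python
from statistics import mean

from vllm import LLM, SamplingParams

# Placeholder model and scoring prompt.
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
prompt = "Rate the following answer from 1 to 10. Reply with a single number.\n..."

# Draw several samples of the same prompt with non-zero temperature.
params = SamplingParams(n=8, temperature=0.8, max_tokens=8)
samples = llm.generate([prompt], params)[0].outputs

# Hypothetical parsing step: read a number off each sample and average.
scores = [float(s.text.strip().split()[0]) for s in samples]
print(mean(scores))
```

Tree attention would mainly make the decode side of this cheaper, since the branches sharing the same prompt could share attention computation over the common prefix instead of being decoded as fully independent sequences.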

yukavio commented 2 months ago

@cadedaniel Thanks for your reply. I have read your document, and it seems the crux of the problem is that each token in the score phase requires looping over and computing against the entire KV cache. I think this can be solved by storing all pre-score tokens of a given sequence at adjacent addresses, instead of treating them as separate sequences after expansion. That way we can perform the computation efficiently on tensor cores with a specific attention mask. However, it requires organizing the pre-score tokens as one sequence (left in the image) rather than as multiple sequences (right in the image).

[image: pre-score tokens organized as one sequence (left) vs. multiple expanded sequences (right)]
If you think this way of organizing the pre-score tokens is appropriate, I can implement the tensor-core CUDA kernel with a tree attention mask.
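
To make the "one sequence" layout on the left concrete, here is a plain PyTorch reference (single head, no paging, all names and shapes illustrative) of what the proposed tensor-core kernel would compute: the pre-score tokens of a request are packed contiguously and scored in a single masked matmul against the shared KV, whereas the layout on the right would run one such computation per expanded sequence.

```python
import torch

def score_candidates_packed(q, k, v, tree_mask):
    """Reference computation for the packed ('one sequence') layout.

    q:         [num_candidates, head_dim]                     queries of the packed pre-score tokens
    k, v:      [prefix_len + num_candidates, head_dim]        contiguous KV, prefix first
    tree_mask: [num_candidates, prefix_len + num_candidates]  boolean ancestor mask
    """
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.T) * scale                                # one matmul for all candidates
    scores = scores.masked_fill(~tree_mask, float("-inf"))    # tree attention mask
    return torch.softmax(scores, dim=-1) @ v                  # [num_candidates, head_dim]

# Toy example: 4-token prefix, candidate tree t0 -> {t1, t2}, t2 -> t3.
head_dim, prefix_len, num_cand = 16, 4, 4
q = torch.randn(num_cand, head_dim)
k = torch.randn(prefix_len + num_cand, head_dim)
v = torch.randn(prefix_len + num_cand, head_dim)
tree_mask = torch.tensor([
    [1, 1, 1, 1, 1, 0, 0, 0],   # t0: prefix + itself
    [1, 1, 1, 1, 1, 1, 0, 0],   # t1: prefix + t0 + itself
    [1, 1, 1, 1, 1, 0, 1, 0],   # t2: prefix + t0 + itself
    [1, 1, 1, 1, 1, 0, 1, 1],   # t3: prefix + t0 + t2 + itself
], dtype=torch.bool)

print(score_candidates_packed(q, k, v, tree_mask).shape)  # torch.Size([4, 16])
```

A paged kernel would apply the same per-candidate mask within each KV-cache block, which is where the intra-block masking discussed above comes in.
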
cadedaniel commented 2 months ago

@yukavio you should talk with @LiuXiaoxuanPKU, who is adding MQA scoring to vLLM.