vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[FEATURE REQUEST] SparQ Attention #2039

Open AlpinDale opened 11 months ago

AlpinDale commented 11 months ago

A newly released paper, SparQ Attention: Bandwidth-Efficient LLM Inference, suggests a method for increasing LLM inference throughput by up to 8x by reducing the memory bandwidth requirements within the attention blocks through selective fetching of the cached history.

A sample implementation looks like this:

from torch import abs, softmax, sqrt, tensor, topk

def gather(t, dim, i):
    # Normalize a negative dim, then gather along it, broadcasting the index
    # tensor across the remaining axes of t.
    dim += (dim < 0) * t.ndim
    return t.gather(dim, i.expand(*t.shape[:dim], i.shape[dim], *t.shape[dim + 1 :]))

def attn(Q, K, V, M):
    # Standard scaled dot-product attention with an additive mask M.
    s = (Q @ K.transpose(-1, -2)) / sqrt(tensor(Q.shape[-1])) + M
    return softmax(s, dim=-1) @ V

def sparq_attn(Q, K, V, V_mean, M, r, k):
    # Approximate attention scores using r largest components of Q
    i1 = topk(abs(Q), r, -1).indices
    Q_hat, K_hat = gather(Q, -1, i1), gather(K, -1, i1)
    scale = sqrt(
        Q.shape[-1]
        * abs(Q_hat).sum(dim=-1, keepdim=True)
        / abs(Q).sum(dim=-1, keepdim=True)
    )
    s_hat = softmax(Q_hat @ K_hat.transpose(-1, -2) / scale + M, dim=-1)

    # Gather top_k positions based on approximate attention scores and run attention
    i2 = topk(s_hat, k, -1).indices
    iKV = i2[..., 0, :, None]
    K, V, M = gather(K, -2, iKV), gather(V, -2, iKV), gather(M, -1, i2)
    y_ = attn(Q, K, V, M)

    # Estimate the total score of the top_k, and interpolate with V_mean
    alpha = gather(s_hat, -1, i2).sum(-1, keepdim=True)
    return alpha * y_ + (1 - alpha) * V_mean
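
For reference, here is a minimal sketch of how this sample could be invoked during single-token decoding. The shapes, the zero mask, and the running V_mean below are my assumptions for illustration, not part of the paper's snippet:

import torch

batch, heads, seq, d_head = 1, 8, 1024, 64
Q = torch.randn(batch, heads, 1, d_head)       # query for the current decode step
K = torch.randn(batch, heads, seq, d_head)     # cached keys
V = torch.randn(batch, heads, seq, d_head)     # cached values
M = torch.zeros(batch, heads, 1, seq)          # additive attention mask (all positions visible)
V_mean = V.mean(dim=-2, keepdim=True)          # mean of cached values, kept alongside the cache

# r: number of query components used for the approximate scores
# k: number of key/value positions actually fetched in full
y = sparq_attn(Q, K, V, V_mean, M, r=16, k=128)
print(y.shape)  # torch.Size([1, 8, 1, 64])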
simon-mo commented 11 months ago

This is interesting. But I think the paper is missing latency numbers. While the memory bandwidth is theoretically reduced, the additional compute steps, without an optimized kernel, might actually slow down inference. I'm curious to hear whether there are practical improvements before committing this feature to vLLM.

AlpinDale commented 11 months ago

After reading the paper a bit more, there seem to be a few points that may make it difficult to integrate into vLLM. Mainly:

In the sample code I linked, the K matrix is indexed along two different axes (the head dimension when computing the approximate scores, and the sequence dimension when gathering the top-k positions), so an implementation would load non-contiguous elements. The authors propose storing K twice, once in each layout.

I'm not sure this is a hit we'd be willing to take, as it also increases KV cache usage by 50%.
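
To make the layout issue concrete, here is a rough sketch (my own illustration, not from the paper or vLLM) of keeping the key cache in two layouts so that each of the two gathers in sparq_attn reads contiguous rows:

import torch

batch, heads, seq, d_head, r, k = 1, 8, 1024, 64, 16, 128
K = torch.randn(batch, heads, seq, d_head)

# Layout 1: each position's d_head-long key vector is contiguous.
# Suits step 2, which fetches k full key vectors.
K_seq_major = K.contiguous()

# Layout 2: each of the d_head components is contiguous across all positions.
# Suits step 1, which fetches r components of every cached key.
K_dim_major = K.transpose(-1, -2).contiguous()

# Step 1: gather r components of every key (in SparQ, i1 comes from the top-|Q| components).
i1 = torch.randint(0, d_head, (batch, heads, 1, r))
K_hat = K_dim_major.gather(
    -2, i1.transpose(-1, -2).expand(batch, heads, r, seq)
).transpose(-1, -2)                                                    # (batch, heads, seq, r)

# Step 2: gather k full positions (in SparQ, i2 comes from the approximate scores).
iKV = torch.randint(0, seq, (batch, heads, k, 1))
K_top = K_seq_major.gather(-2, iKV.expand(batch, heads, k, d_head))    # (batch, heads, k, d_head)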

hudlass commented 8 months ago

One of the SparQ Attention authors here, thanks for your interest in our work! We have recently released an updated version of our paper which includes microbenchmark results (arxiv.org/abs/2312.04985). These results show that for large batch sizes and sequence lengths (the regime in which SparQ can provide the biggest improvements), we can attain a >4x speedup on A10s. We hope these results address some of the concerns about SparQ's practical improvements.

Based on the previous discussion in this thread, we're aware that the 50% memory overhead is a concern. Is this something that would limit SparQ's utility for the use cases of interest?

As a team, we are eager for SparQ Attention to be used by the wider ML community, and we are keen to support any attempt to implement SparQ in libraries such as vLLM. Please feel free to raise any questions about the method, or any ongoing concerns about deploying SparQ in practice.

simon-mo commented 8 months ago

adding @robertgshaw2-neuralmagic @WoosukKwon to hear their take on this

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!