AlpinDale opened this issue 11 months ago
This is interesting, but I think the paper is missing latency numbers. While the memory bandwidth is theoretically reduced, the additional compute steps, without an optimized kernel, might actually slow down inference. I'm curious to hear whether there are practical improvements before committing this feature to vLLM.
After reading the paper a bit more, there seem to be a few points that may make it more difficult to integrate into vLLM. Mainly:
- In the sample code I linked, the K matrix is indexed along two different axes, so an implementation would load non-contiguous elements. The authors propose storing K twice.
I'm not sure if this is a hit we'd be willing to take as that also increases KV cache usage by 50%.
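To make the layout concern concrete, here is a rough single-head sketch (not the paper's reference code; shapes, names, and the `r`/`k` values are illustrative assumptions): the approximate-scoring step slices K along the head dimension, while the final gather slices it along the sequence dimension, which is why a transposed copy of K makes both reads contiguous.

```python
# Minimal sketch (not the authors' code): why SparQ touches K along two
# different axes, and why a transposed copy of K makes both reads contiguous.
# Shapes, names, and r/k here are assumptions for illustration only.
import torch

seq_len, head_dim, r, k = 4096, 128, 32, 256

q = torch.randn(head_dim)            # current query vector
K = torch.randn(seq_len, head_dim)   # row-major KV-cache layout
K_t = K.t().contiguous()             # optional second copy (K transposed)

# Step 1: approximate scores using only the r largest-magnitude query components.
# This slices K along the head_dim axis -> strided reads from K,
# but contiguous rows if we read from K_t instead.
top_dims = q.abs().topk(r).indices
approx_scores = q[top_dims] @ K_t[top_dims, :]           # (seq_len,)

# Step 2: gather the full K rows for the top-k highest-scoring positions.
# This slices K along the sequence axis -> contiguous rows of K.
top_pos = approx_scores.topk(k).indices
scores = (q @ K[top_pos].t()) / head_dim ** 0.5          # exact scores on k rows
```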
One of the SparQ Attention authors here, thanks for your interest in our work! We have recently released an updated version of our paper which includes microbenchmark results (arxiv.org/abs/2312.04985). These results show that for large batch sizes and sequence lengths (the regime in which SparQ can provide the biggest improvements), we can attain >4x speedup on A10s. We hope these results address some of the concerns regarding SparQ's practical improvements.
Based on the previous discussion in this thread, we're aware that the 50% memory overhead is a concern. Is this something that would limit the utility for cases of interest?
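For a sense of scale, here is some back-of-the-envelope arithmetic assuming a 7B-class model in fp16 (32 layers, 32 KV heads, head dimension 128); the configuration is an illustrative assumption, not taken from the paper:

```python
# Back-of-the-envelope KV-cache sizes for an assumed 7B-class config in fp16;
# the model shape here is illustrative, not taken from the paper.
layers, kv_heads, head_dim, bytes_per_elem = 32, 32, 128, 2
kv_per_token = 2 * layers * kv_heads * head_dim * bytes_per_elem   # K and V
extra_k_copy = kv_per_token // 2                                   # second copy of K only
seq_len = 4096
print(kv_per_token * seq_len / 2**30)    # ~2.0 GiB of KV cache per sequence
print(extra_k_copy * seq_len / 2**30)    # ~1.0 GiB extra if K is stored twice
```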
As a team, we are eager for SparQ Attention to be used by the wider ML community, and hence are keen to support any attempts to implement SparQ in libraries such as vLLM. We would therefore like to invite any questions you may have about the method, as well as any ongoing concerns about deploying SparQ in practice.
adding @robertgshaw2-neuralmagic @WoosukKwon to hear their take on this
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!
A newly released paper, SparQ Attention: Bandwidth-Efficient LLM Inference, suggests a method for increasing the inference throughput of LLMs by up to 8x by reducing the memory bandwidth requirements within the attention blocks through selective fetching of the cached history.
A sample implementation looks like this:
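The snippet below is a minimal single-head sketch of the three steps described in the paper (approximate scores from the top-r query components, gather the top-k positions, then reallocate the remaining mass to the mean value), not the authors' reference code; variable names, the choice of `r` and `k`, and the simplified softmax scaling in step 1 are assumptions for illustration.

```python
# Minimal single-head PyTorch sketch of the SparQ idea, written from the
# paper's description rather than the linked reference code. Names, r/k,
# and the simplified scaling in step 1 are illustrative assumptions.
import torch

def sparq_attention(q, K, V, r=32, k=128):
    """q: (d,), K: (s, d), V: (s, d) for a single head and a single query."""
    s, d = K.shape

    # Step 1: approximate scores using only the r largest-magnitude components
    # of q, so only r of the d columns of K need to be fetched.
    i1 = q.abs().topk(r).indices
    s_hat = torch.softmax(q[i1] @ K[:, i1].t() / d ** 0.5, dim=-1)   # (s,)

    # Step 2: keep the k positions with the largest approximate scores and
    # fetch only those rows of K and V for the exact attention computation.
    i2 = s_hat.topk(k).indices
    alpha = s_hat[i2].sum()                       # approximate mass covered by the top-k

    w = torch.softmax(q @ K[i2].t() / d ** 0.5, dim=-1)   # exact weights over k rows
    y_top = w @ V[i2]                             # (d,)

    # Step 3: reallocate the remaining attention mass to the mean of V
    # instead of dropping it entirely.
    v_mean = V.mean(dim=0)
    return alpha * y_top + (1 - alpha) * v_mean

# Example usage with illustrative sizes.
q = torch.randn(128)
K = torch.randn(4096, 128)
V = torch.randn(4096, 128)
out = sparq_attention(q, K, V)
```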