kmn1024 closed this issue 9 months ago
Thanks for the request! Attention sink is definitely something on our minds, and we will probably support it soon! cc @junrushao @davidpissarra
Thanks for your quick response Charlie!
I would like to offer my help on this one, since it looks relatively beginner-friendly, given the example implementation in https://github.com/tomaarsen/attention_sinks. Would it be OK for me to try to work on this over the next few days, and send you a PR for one model (whilst lighting a path to enable it for all other models)?
If experts such as yourself, Junru, or David have extra cycles, may I suggest something much more complicated and impactful: https://github.com/FasterDecoding/Medusa. Medusa is probably one of the most scalable speculative-decoding implementations, since it doesn't require a separate draft model. The claimed gains in toks/sec are impressive. It seems to fit well with MLC's focus on universality and high performance.
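For readers unfamiliar with the idea, here is a rough sketch of the verify step that speculative-decoding schemes share. Medusa drafts candidate tokens with extra decoding heads rather than a separate draft model, but the accept-the-longest-agreeing-prefix step is the same. The helper names below are illustrative (not from Medusa or MLC), and this shows only the greedy variant:

```python
def verify_draft(target_next_token, prompt, draft_tokens):
    """Accept draft tokens until the target model disagrees.

    `target_next_token(context)` stands in for one target-model prediction;
    in a real system all draft positions are scored in a single batched
    forward pass, which is where the speedup comes from.
    """
    accepted = []
    context = list(prompt)
    for tok in draft_tokens:
        expected = target_next_token(context)
        if tok != expected:
            accepted.append(expected)  # fall back to the target's own token
            break
        accepted.append(tok)           # draft token verified, keep it
        context.append(tok)
    else:
        # Every draft token matched: emit one bonus token for free.
        accepted.append(target_next_token(context))
    return accepted
```

The output always ends with a target-model token, so quality matches plain decoding; the win is that several tokens can be confirmed per target-model pass.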
Thank you so much for offering help! We really appreciate it.
However, implementing attention sinks may not be the most beginner-friendly task, as we handle most of the KV cache logic in a lower-level stack called TVM. For instance, when we introduced sliding-window attention, some work needed to be done in TVM: https://github.com/apache/tvm/pull/15963.
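To give a sense of the kind of KV-cache logic involved, here is a toy Python sketch of the sliding-window idea: a ring buffer over cache slots, where new entries overwrite the oldest once the window is full. This is illustrative only, not the actual TVM implementation:

```python
class SlidingWindowKVCache:
    """Toy ring-buffer cache; a flat list stands in for per-layer K/V tensors."""

    def __init__(self, window_size):
        self.window_size = window_size
        self.cache = []
        self.write_pos = 0  # next slot to overwrite once the buffer is full

    def append(self, entry):
        if len(self.cache) < self.window_size:
            self.cache.append(entry)
        else:
            self.cache[self.write_pos] = entry  # evict the oldest entry
        self.write_pos = (self.write_pos + 1) % self.window_size

    def attention_window(self):
        # Reassemble oldest-to-newest order for attention over recent tokens.
        return self.cache[self.write_pos:] + self.cache[:self.write_pos]
```

The subtlety that makes this nontrivial in TVM is that the rotation has to be expressed inside compiled attention kernels, not in Python.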
Regarding speculative decoding, it is definitely something we are considering as well. We are currently working on another front for model serving (in contrast with simple chatting), which will probably include speculative decoding.
Thanks for the heads up! Looking forward to an implementation of this, and speculative decoding too.
@CharlieFRuan @davidpissarra So I didn't heed Charlie's warning, and attempted to implement attention sink. Ended up pulling most of my hair out, but I have something that seems to work. The changes are in the two repos that Charlie pointed out: https://github.com/mlc-ai/relax/compare/mlc...kmn1024:relax_attention_sinks:main https://github.com/mlc-ai/mlc-llm/compare/main...kmn1024:llm_attention_sinks:main (please ignore conv_templates.cc)
I will try to get some /stats tomorrow. Is this something that I can send to you two for review? No worries at all if you prefer not, I can see that you guys are super busy =)
@kmn1024 Wow this is really impressive, thank you for the hard work!
We are in the process of migrating from the `relax_model` folder to SLIM, essentially a new workflow for compiling models on the mlc-llm layer. We are still wrapping it up and writing documentation for it. Therefore, the changes in `lm_support.cc` and `llm_chat.cc` would not be affected, but those in `relax_model` and `mlc_llm/core.py` may need to be migrated later when the new workflow is up.
With that being said, once you are ready, feel free to open a PR for both the TVM side and the mlc-llm side (old workflow is fine), then @davidpissarra and/or I will help review. We can later pick the changes to the new workflow.
Really appreciate the contribution!
Thanks Charlie! I'll begin sending PRs.
Here's a screenshot of the code in action, with added logs to show when the cache trimming happens:
@kmn1024 Looks great, thank you so much! We'll look at the PRs.
Hi, @kmn1024! Regarding sinks, since most of the SWA logic was already implemented, we were able to reuse the `WindowOverride` function from SWA to implement it (a three-line change, see https://github.com/apache/tvm/pull/16240). As of now, Mistral is the only architecture that supports SWA and sinks (#1435). Part of our effort now is to bring sinks to the other models.
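For intuition, here is a hedged sketch of what that reuse amounts to: the first `num_sinks` cache slots are pinned, while the sliding-window ring buffer cycles only through the remaining slots. The function name and shape below are illustrative, not the actual `WindowOverride` code:

```python
def window_override_with_sinks(seq_len, window_size, num_sinks):
    """Map each absolute token position to the cache slot it is written to."""
    slots = []
    rotating = window_size - num_sinks  # slots available for recent tokens
    for pos in range(seq_len):
        if pos < num_sinks:
            slots.append(pos)  # sink tokens keep their slots forever
        else:
            # Later tokens cycle through the non-sink region of the cache.
            slots.append(num_sinks + (pos - num_sinks) % rotating)
    return slots
```

The appeal of piggybacking on SWA is visible here: only the wrap-around arithmetic changes, everything else about the cache stays as in plain sliding-window attention.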
Thanks, really appreciate your work! I will close out my PRs.
I want to take this opportunity to ask you (@davidpissarra) and Charlie (@CharlieFRuan) whether there's more information on speculative decoding. Even with a shortened context window, the steady-state throughput is still too low, and extrapolating from the Medusa numbers (my board has a smaller warp size), I think something similar would give me a much-needed ~1.3x boost.
I would be happy to sponsor a bounty, but 1. the TVM/MLC group seems better funded than myself; 2. I'm not sure if bounties are a source of motivation in academia. Anyhow, I'm laying this out there, please let me know =)
Thanks for proposing it! Unfortunately, there isn't much of an update on speculative decoding, as the team is occupied with various other things (e.g. serving, multi-GPU, etc.). We will probably look into it after https://github.com/mlc-ai/mlc-llm/tree/serving lands.
Si-ze and I are working on speculative decoding on the serving branch.
Thanks for the replies! @junrushao do you plan to implement it "bring your own draft model" style, integrated like Medusa, or something else?
@junrushao Hi, when will Medusa be supported on the serve branch? Do you have any plans?
🚀 Feature
Add Attention Sinks (https://arxiv.org/pdf/2309.17453.pdf, https://github.com/tomaarsen/attention_sinks/) to MLC.
Motivation
mlc_chat_cli gets noticeably slower as the conversation progresses. I tried this on an Orange Pi 5 with two setups: exactly as described in https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi, and then compiling my own StableLM 3B (q4f16_1, OpenCL). You can see from these screenshots that toks/sec gradually decreases over the course of the conversation (I waited between generations to ensure it wasn't due to thermal throttling).
Such a slowdown is unavoidable given the nature of attention, but perhaps we can reduce the latency hit without affecting decoding quality too much by using Attention Sinks (Figure 1 from the paper). The default cache setting for most models is windowed attention with window size = sequence length; with Attention Sinks, maybe we can use something smaller.
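Concretely, the paper's eviction policy keeps the first few "sink" tokens plus a recent tail of the window. A toy sketch of that policy (a hypothetical helper operating on token ids for illustration; a real implementation would trim the per-layer K/V tensors instead):

```python
def trim_kv_cache(tokens, window_size, num_sinks):
    """Keep the first `num_sinks` tokens plus the most recent tail.

    Mirrors the StreamingLLM/Attention-Sinks idea: the initial tokens act as
    attention sinks and must survive eviction, while the middle of the
    sequence is dropped once the cache exceeds `window_size` entries.
    """
    if len(tokens) <= window_size:
        return tokens  # nothing to evict yet
    return tokens[:num_sinks] + tokens[-(window_size - num_sinks):]
```

Because the cache size is then bounded by `window_size` regardless of conversation length, per-token decode cost stops growing, which is exactly the slowdown described above.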
Alternatives
It seems there will always be a latency-vs-quality trade-off for any type of cache, but perhaps Attention Sinks currently offers the best trade-off.
Additional context
I would love to work on this, if I can get a mentor to point out which files should be changed for a tidy implementation!