mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

Add Attention Sinks #1357

Closed: kmn1024 closed this 9 months ago

kmn1024 commented 9 months ago

🚀 Feature

Add Attention Sinks (https://arxiv.org/pdf/2309.17453.pdf, https://github.com/tomaarsen/attention_sinks/) to MLC.

Motivation

mlc_chat_cli gets noticeably slower as the conversation progresses. I tried this on an Orange Pi 5 with two setups: exactly as described in https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi, and then compiling my own StableLM 3B (q4f16_1, OpenCL). The screenshots below show that toks/sec gradually decreases as the conversation grows (I waited between generations to make sure it wasn't thermal throttling).

[Screenshots: toks/sec decreasing over the course of the conversation]

Some slowdown is unavoidable given the nature of attention, but Attention Sinks (figure 1 of the paper) might reduce the latency hit without hurting decoding quality too much. The default cache setting for most models is window attention with window size = sequence length; with Attention Sinks we could perhaps use a much smaller window.
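For concreteness, here is a minimal sketch of the cache policy from the paper, with made-up parameter names (`num_sinks`, `window_size`); this is not MLC's API, just the idea of keeping the first few "sink" tokens plus a rolling window of recent tokens and evicting everything in between:

```python
# Illustrative sketch of the eviction policy from the StreamingLLM paper,
# not MLC's implementation. `kv_cache` is a per-token list of KV entries;
# `num_sinks` and `window_size` are made-up parameter names.

def evict_with_sinks(kv_cache, num_sinks=4, window_size=1024):
    """Keep the first `num_sinks` tokens plus the most recent
    `window_size - num_sinks` tokens; drop everything in between."""
    if len(kv_cache) <= window_size:
        return kv_cache
    recent = window_size - num_sinks
    return kv_cache[:num_sinks] + kv_cache[-recent:]
```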

Alternatives

There will always be a latency-versus-quality trade-off for any kind of cache eviction, but Attention Sinks currently seems to offer the best one.

Additional context

I would love to work on this, if I can get a mentor to point out which files should be changed for a tidy implementation!

CharlieFRuan commented 9 months ago

Thanks for the request! Attention sinks are definitely on our minds, and we will probably support them soon! cc @junrushao @davidpissarra

kmn1024 commented 9 months ago

Thanks for your quick response Charlie!

I would like to offer my help on this one, since it looks relatively beginner-friendly given the example implementation in https://github.com/tomaarsen/attention_sinks. Would it be OK for me to work on this over the next few days and send you a PR for one model (while laying out a path to enable it for all other models)?

If experts such as yourself, Junru, or David have extra cycles, may I suggest something much more complicated and impactful: https://github.com/FasterDecoding/Medusa. Medusa is probably one of the most scalable speculative-decoding approaches, since it doesn't require a separate draft model. The claimed gains in toks/sec are impressive, and it seems to fit well with MLC's focus on universality and high performance.
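To make the suggestion concrete, here is a heavily simplified sketch of Medusa-style decoding (greedy, no tree attention, no KV cache; `base_model` and `medusa_heads` are hypothetical interfaces, not an existing API): the extra heads propose the next few tokens from the current hidden state, and a single base-model forward pass checks how many of them the model would have produced anyway.

```python
# Heavily simplified Medusa-style step: greedy decoding, no tree attention,
# no KV cache. `base_model` returns (last_hidden, per_position_logits) and
# each head in `medusa_heads` maps a hidden state to next-token logits;
# both are hypothetical interfaces used only for illustration.

def medusa_step(base_model, medusa_heads, tokens):
    hidden, logits = base_model(tokens)
    proposal = [int(logits[-1].argmax())]      # base model's own next token
    for head in medusa_heads:                  # each head guesses one step further
        proposal.append(int(head(hidden).argmax()))

    # A single forward pass verifies the whole candidate continuation.
    _, verify_logits = base_model(tokens + proposal)
    accepted = [proposal[0]]                   # always consistent with the base model
    for i in range(1, len(proposal)):
        # Logits at the position of proposal[i-1] say what should follow it.
        if int(verify_logits[len(tokens) + i - 1].argmax()) != proposal[i]:
            break
        accepted.append(proposal[i])
    return tokens + accepted
```

The payoff is that several accepted tokens cost only two forward passes instead of one pass per token; the real Medusa uses tree attention to verify many candidate continuations at once.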

CharlieFRuan commented 9 months ago

Thank you so much for offering help! We really appreciate it.

However, implementing attention sinks may not be the most beginner-friendly task, as most of the KV cache logic is handled in a lower-level stack, TVM. For instance, when we introduced sliding-window attention, some of the work had to be done in TVM: https://github.com/apache/tvm/pull/15963.

Regarding speculative decoding, it is definitely something we are considering as well. We are currently working on another front for model serving (in contrast with simple chatting), which will probably include speculative decoding.

kmn1024 commented 9 months ago

Thanks for the heads up! Looking forward to an implementation of this, and to speculative decoding too.

kmn1024 commented 9 months ago

@CharlieFRuan @davidpissarra So I didn't heed Charlie's warning and attempted to implement attention sinks. I ended up pulling most of my hair out, but I have something that seems to work. The changes are in the two repos Charlie pointed out:
https://github.com/mlc-ai/relax/compare/mlc...kmn1024:relax_attention_sinks:main
https://github.com/mlc-ai/mlc-llm/compare/main...kmn1024:llm_attention_sinks:main (please ignore conv_templates.cc)

I will try to get some /stats tomorrow. Is this something that I can send to you two for review? No worries at all if you prefer not, I can see that you guys are super busy =)

CharlieFRuan commented 9 months ago

@kmn1024 Wow this is really impressive, thank you for the hard work!

We are in the process of migrating from the relax_model folder to SLIM, essentially a new workflow for compiling models on the mlc-llm layer. We are still wrapping it up and writing documentation for it.

Therefore, the changes in lm_support.cc and llm_chat.cc would not be affected, but those in relax_model and mlc_llm/core.py may need to be migrated once the new workflow is in place.

With that being said, once you are ready, feel free to open PRs for both the TVM side and the mlc-llm side (the old workflow is fine), and @davidpissarra and/or I will help review. We can port the changes to the new workflow later.

Really appreciate the contribution!

kmn1024 commented 9 months ago

Thanks Charlie! I'll begin sending PRs.

Here's a screenshot of the code in action, with added logs showing when cache trimming happens:

[Screenshot: chat session with log lines marking cache-trim events]

CharlieFRuan commented 9 months ago

@kmn1024 Looks great, thank you so much! We'll look at the PRs.

davidpissarra commented 9 months ago

Hi, @kmn1024! Regarding sinks: since most of the sliding-window logic was already implemented, we were able to reuse the WindowOverride function from SWA to implement them (a three-line change, see https://github.com/apache/tvm/pull/16240). As of now, Mistral is the only architecture that supports SWA and sinks (#1435). Part of our effort now is bringing sinks to the other models.
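For readers following along, here is a rough schematic of why the SWA machinery can be reused: with a rolling KV cache, adding sinks only changes which slot a new token overwrites. This is an illustration under assumed parameter names (`window_size`, `num_sinks`), not the actual code in the TVM PR.

```python
# Schematic of rolling-cache slot assignment with attention sinks; an
# illustration of the idea, not the code in apache/tvm#16240.

def cache_slot(pos, window_size, num_sinks):
    """Cache slot for the token at absolute position `pos`: the first
    `num_sinks` positions are pinned, and later tokens rotate through
    the remaining `window_size - num_sinks` slots."""
    if pos < num_sinks:
        return pos
    rolling = window_size - num_sinks
    return num_sinks + (pos - num_sinks) % rolling

# e.g. window_size=8, num_sinks=2: positions 0-7 fill slots 0-7, then
# position 8 overwrites slot 2 (the oldest non-sink entry).
```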

kmn1024 commented 9 months ago

Thanks, really appreciate your work! I will close out my PRs.

I want to take this opportunity to ask you (@davidpissarra) and Charlie (@CharlieFRuan) if there's more information on speculative decoding. Even with a shortened context window, the steady-state throughput is still too low; extrapolating from the Medusa numbers (my board has a smaller warp size), I think something similar would give me a much-needed ~1.3x boost.

I would be happy to sponsor a bounty, but (1) the TVM/MLC group seems better funded than I am, and (2) I'm not sure bounties are much of a motivator in academia. Anyhow, I'm putting this out there; please let me know =)

CharlieFRuan commented 9 months ago

Thanks for proposing it! Unfortunately, there isn't much of an update on speculative decoding, as the team is occupied with various other things (serving, multi-GPU, etc.). We will probably look into it after https://github.com/mlc-ai/mlc-llm/tree/serving lands.

junrushao commented 9 months ago

Si-ze and I are working on speculative decoding on the serving branch.

kmn1024 commented 9 months ago

Thanks for the replies! @junrushao, do you plan to implement it "bring your own draft model" style, integrated like Medusa, or something else?

jpf888 commented 5 months ago

@junrushao Hi, when will Medusa be supported on the serve branch? Do you have any plans for it?