shermansiu opened this issue 11 months ago
FYI @gante so we keep track of this. Shared offline, but it might be good to tackle after the cache refactoring.
The current reference implementation builds directly on top of Hugging Face transformers, but the authors have mentioned that they plan to release a custom CUDA kernel to speed up the method.
Should we wait for this kernel? (My opinion: No, we shouldn't wait. Plus, I'm skeptical about whether such a kernel would be compatible with Flash Attention's own CUDA kernel, but we'll see.)
Cache refactoring PR: #26681
While we're waiting for the KV cache refactor to be completed, I think it might be worth considering how exactly to manage the Lookahead Decoding configuration, especially since it has a few associated parameters (e.g. the lookahead window size and the N-gram size).
I suppose it would be better to introduce a LookaheadDecoderConfig dataclass for this?
No, I think these can just be passed in the generation config.
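For illustration, a minimal sketch of what the generation-config route could look like; the lookahead_* fields are hypothetical names for the window and N-gram sizes, not existing GenerationConfig attributes:

```python
from transformers import GenerationConfig

# Hypothetical fields: neither `lookahead_window_size` nor `lookahead_ngram_size`
# exists in transformers today; they only illustrate the "pass it via the
# generation config" option discussed above.
generation_config = GenerationConfig(
    do_sample=False,           # lookahead decoding currently targets greedy decoding
    max_new_tokens=256,
    lookahead_window_size=7,   # hypothetical: lookahead window size
    lookahead_ngram_size=5,    # hypothetical: N-gram size
)

# Extra keyword arguments are stored as attributes on the config object, so a
# lookahead-aware generate() could later read them, e.g.:
# model.generate(**inputs, generation_config=generation_config)
```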
Hi @shermansiu 👋
Before commenting here, I've spent some time playing with lookahead decoding. In particular, I used a modified version of their minimal.py so I could benchmark against datasets. I'm pasting an example in the collapsible below.
Here are some findings (a rough sketch of the harness follows this list):
👉 As mentioned in the blog post, you are increasing FLOPS to get additional LLM throughput. All is good if the model is small for your device, but it's hard to achieve speedups using modest models on consumer GPUs (e.g. 7B models in a 3090)
👉 After some fiddling with the LADE parameters, I was able to get a 25% speedup on a 7B model in a 3090, compared to the model without FA2. Running with their default parameterization actually slows the model down by 33%, despite achieving a high compression ratio (= FLOPS is the bottleneck)
👉 Doesn't work correctly with FA2: the output is significantly different
👉 Works with BNB, but I didn't manage to get a speedup on my setup, only slowdowns
👉 Works with AWQ, same findings as in the case without quantization
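For reference, a rough sketch of that kind of benchmark harness. The lade.augment_all() / lade.config_lade() calls follow the LookaheadDecoding repository's README (an assumption here), and the model and parameter choices are illustrative rather than the exact settings behind the numbers above:

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# LADE entry points per the LookaheadDecoding README (an assumption here);
# the repo may also require USE_LADE=1 in the environment.
import lade

lade.augment_all()  # patch the llama modeling code
lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=0)  # illustrative values

model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

# Placeholder prompts; in practice, a slice of an instruction/chat dataset.
prompts = [
    "Explain the difference between a process and a thread.",
    "Write a short poem about GPUs.",
]

total_tokens, total_time = 0, 0.0
for prompt in prompts:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    start = time.perf_counter()
    out = model.generate(**inputs, do_sample=False, max_new_tokens=256)
    total_time += time.perf_counter() - start
    total_tokens += out.shape[1] - inputs["input_ids"].shape[1]

print(f"throughput: {total_tokens / total_time:.1f} tokens/s")
```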
On top of that, from the blog post we know that:
👉 It requires changes in the modeling code of each model, so it will require a lot of work to add and to maintain
👉 It is limited to greedy decoding, meaning that it doesn't support the most common use case (do_sample=True)
👉 Batching with this technique is much trickier -- just like in speculative decoding/assisted generation, we may have more than one accepted token per forward pass
The idea does look very promising -- it would be amazing to be able to speed up a model without relying on external models. However, the current benefits are limited to GPU-rich users running a GPU that is oversized for the task at hand, and the cost of adding it is heavy, especially the model-level changes. The original code is also open-source and transformers-compatible, despite being limited to llama.
If a model-independent solution can be achieved, more positive benchmarks are found, or upgrades to the technique are released, I'd be happy to reconsider this decision!
Let's keep this issue open for discussion 🤗
^ Some of the acronyms in the above response:
LADE = Lookahead decoding
FA2 = Flash Attention 2
BNB = bitsandbytes
AWQ = Activation-aware Weight Quantization
The authors mentioned that they are working on an FA2-compatible CUDA kernel, so hopefully we'll see better results soon!
BTW, here's a PR where we are looking at adding sampling support.
Feature request
Fu et al. propose a novel decoding technique that accelerates greedy decoding on Llama 2 and Code Llama by 1.5-2x across various parameter sizes, without a draft model. The method can also be extended to beam search decoding.
Blog post: https://lmsys.org/blog/2023-11-21-lookahead-decoding/ Code: https://github.com/hao-ai-lab/LookaheadDecoding
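To make the shape of the technique a bit more concrete, below is a simplified sketch of the verification half only: a candidate n-gram (which lookahead decoding would draw from a pool built by its Jacobi lookahead branch) is checked against the model's own greedy predictions in a single forward pass. This is an illustration, not the authors' implementation; it omits the lookahead branch, the custom attention mask, and KV-cache reuse.

```python
import torch


@torch.no_grad()
def greedy_verify(model, input_ids: torch.LongTensor, candidate: torch.LongTensor) -> torch.LongTensor:
    """Accept the longest prefix of `candidate` that matches greedy decoding.

    input_ids: shape (1, L), the current sequence.
    candidate: shape (k,), a proposed continuation (e.g. from an n-gram pool).
    Returns the extended sequence of shape (1, L + n_accept + 1).
    """
    full = torch.cat([input_ids, candidate.unsqueeze(0)], dim=-1)   # (1, L + k)
    logits = model(full).logits                                     # (1, L + k, vocab)
    # Greedy predictions for the k candidate positions, plus one bonus position.
    preds = logits[0, input_ids.shape[1] - 1 :, :].argmax(dim=-1)   # (k + 1,)
    n_accept = 0
    for i in range(candidate.shape[0]):
        if preds[i].item() != candidate[i].item():
            break
        n_accept += 1
    # Every accepted token equals what plain greedy decoding would have produced,
    # and the prediction after the last accepted token comes "for free".
    bonus = preds[n_accept : n_accept + 1].unsqueeze(0)             # (1, 1)
    return torch.cat([input_ids, candidate[:n_accept].unsqueeze(0), bonus], dim=-1)
```

In the full method, the same forward pass also advances the lookahead branch that refreshes the n-gram pool, which is where the extra FLOPs discussed above come from.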
Motivation
Lookahead decoding provides a massive speedup at a worthwhile tradeoff (namely, a windowed n-gram cache and a custom attention mask). There have been proposals to integrate lookahead decoding into other libraries like TGI or vLLM, but it seems that this specific feature would be best integrated into the core transformers library, the same way that Flash Attention has been.
Your contribution
I'm busy with thesis work, but I can submit a PR based on the original implementation here if I have time.