Is there any plan to have an integration with FastChat, vLLM, or Huggingface TGI?
FastChat supports vLLM integration. TGI supports vLLM integration in versions ≥ 0.9 with supported models.
The current implementation of look-ahead decoding is built on top of Huggingface transformers, though a CUDA kernel is coming.
Obviously, the maximum benefit would be derived from integrating this into Huggingface Transformers directly, as this approach can work without paged attention. But the maintainers are already concerned about not overcomplicating the generate function.
Just a follow-up question: given that this approach can work without paged attention, can vLLM and lookahead decoding be used via FastChat for all of the models that are part of FastChat and Chatbot Arena? Is there something on the roadmap?
Like I said earlier, FastChat supports vLLM integration. vLLM should work for all of the models listed here.
I'm trying to get this integrated into Huggingface Transformers.
For the sake of clarity, I am also not one of the co-authors of Lookahead Decoding.
Thank you for the clarification @shermansiu
@Viol2000 Will the upcoming CUDA kernel be compatible with Flash Attention? Flash Attention has its own CUDA kernels and I wouldn't want to have to choose between one or the other.
Hi @shermansiu, we will release a CUDA kernel based on FlashAttention very soon. Yes, it will be compatible with Flash Attention, so you will not need to choose between one or the other.
About the integration problem, we currently support huggingface/transformers. You can speed up transformers' native generate function in a few LoCs, as in our example code (a rough sketch is below); we will make this clearer in the README. Note that we only support the LLaMA model and greedy search so far.
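Roughly, the usage looks like the following. Treat this as a sketch: the lade call names and arguments here are approximate, so please check the README/example code for the exact API, and the model id is just a placeholder.

```python
# Sketch only: assumes lade exposes augment_all()/config_lade() as in the
# example code; a LLaMA checkpoint with greedy search, per the note above.
import lade
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

lade.augment_all()  # patch transformers' generation path (assumed entry point)
lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=0)  # assumed knobs

model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder LLaMA checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="cuda"
)

inputs = tok("Explain lookahead decoding in one paragraph.", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=128, do_sample=False)  # greedy search only
print(tok.decode(out[0], skip_special_tokens=True))
```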
We will also integrate Lookahead Decoding with vLLM soon.
Thanks for your interest!
The current implementation works by monkey-patching some of Llama's methods, which isn't ideal from a software development perspective. It would be better to add the changes in lade/models/llama.py to the original file itself.
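(For context, the monkey-patching pattern looks roughly like this; the patch target and method name below are illustrative, not the exact ones lade replaces.)

```python
# Illustrative only: the general shape of the monkey-patching approach.
# lade's actual patch targets and method names may differ.
import transformers.models.llama.modeling_llama as modeling_llama

_original_forward = modeling_llama.LlamaModel.forward

def lookahead_forward(self, *args, **kwargs):
    # A real patch would run the lookahead and verification branches around
    # the original attention computation; here we just delegate.
    return _original_forward(self, *args, **kwargs)

# Swapping the method on the class affects every LlamaModel instance and has
# to track upstream changes by hand, which is why upstreaming would be cleaner.
modeling_llama.LlamaModel.forward = lookahead_forward
```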
If I have time, I'm willing to contribute this to Huggingface transformers directly. This will allow lookahead decoding to be used by many more people and will require fewer package dependencies for the best LLM inference. huggingface/transformers#27649
Thank you for pointing out the drawback of our current implementation. We agree that moving the changes from lade/models/llama.py directly into the original file would be a more robust approach.
Your willingness to contribute to Huggingface transformers directly is greatly appreciated! It indeed sounds like a valuable enhancement, enabling broader use of lookahead decoding and reducing package dependencies for optimal LLM inference. We'll keep an eye on your progress and contributions on this. If you need any extra help, please reach out.
Thanks!
Currently, I'm trying to assess the best method to integrate this moving forward.
It might be best to integrate the CUDA kernel into https://github.com/Dao-AILab/flash-attention to ensure that Lookahead Decoding works with any new inference optimizations that Tri Dao comes up with. But then again, that would be subject to his discretion because maintaining the kernel would be extra work.
Also, n-gram caching might be out of scope for the flash-attn package, but it would be in scope for 🤗 transformers. Tom Aarsen is currently working on a KV cache refactor for transformers, and Lookahead Decoding can be integrated after its completion.
Thank you for your coordination!
We're committed to ensuring our CUDA kernel stays compatible with Tri Dao's updates in the flash-attention repo, and we will maintain it accordingly whether or not it is integrated into flash-attention's official repo.
Also, we're keeping an eye on the KV cache refactor in transformers to integrate Lookahead Decoding post-refactor.
Also, as a slight implementation detail, it might be more efficient to store the n-grams as a trie. It should improve the performance of lookahead decoding somewhat.
Thanks for your constructive insight! Building a trie would reduce the number of input tokens but would complicate building the attention mask. We are looking for a way to achieve better performance.
Please correct me if I'm misunderstanding something, but I don't think using a trie would affect the attention mask at all? The trie just helps identify matching n-grams more quickly during verification. So instead of taking $O(G)$ time, it would take amortized $O(N)$ time to assess whether the cache contains the n-gram we're searching for.
(Then again, with such low values of $G=W$ and $N$, any performance difference should be negligible in practice.)
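To make the suggestion concrete, here is a rough sketch of the kind of trie I have in mind; the class and method names are made up, and this is not code from the repo.

```python
# Rough sketch of an n-gram trie keyed on token ids; assumes all stored
# n-grams have the same length, as in the lookahead pool. Not repo code.
class NGramTrie:
    def __init__(self):
        self.children = {}  # token id -> child NGramTrie

    def insert(self, ngram):
        node = self
        for tok in ngram:
            node = node.children.setdefault(tok, NGramTrie())

    def match(self, first_token):
        """Return all stored n-grams whose first token equals first_token."""
        results = []
        child = self.children.get(first_token)
        if child is None:
            return results

        def walk(node, prefix):
            if not node.children:          # leaf: a complete n-gram
                results.append(prefix)
            for tok, nxt in node.children.items():
                walk(nxt, prefix + [tok])

        walk(child, [first_token])
        return results

pool = NGramTrie()
pool.insert([42, 7, 9])    # token-id n-grams collected by the lookahead branch
pool.insert([42, 7, 11])
pool.insert([8, 3, 5])
print(pool.match(42))      # [[42, 7, 9], [42, 7, 11]]
```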
Hello, I believe we might be discussing different aspects of the verification process. There are two components: the GPU side generates softmax values, while the CPU side verifies tokens based on these values. The time complexities $O(G)$ and $O(N)$ likely refer to the CPU side.
Regarding the use of a trie, it would necessitate a redesign of the GPU-side verification branch, as illustrated in the bottom-right part of Figure 5 in the blog. The current system processes discrete n-grams, but with a trie, these n-grams would become interdependent, altering the attention mask. This approach could significantly reduce costs, especially when dealing with a large number of guess n-grams.
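To illustrate what I mean by discrete n-grams, the verification branch's mask has roughly this block structure (illustrative pseudocode only, not our actual implementation; the function and parameter names are made up):

```python
import torch

def verification_mask(prefix_len, num_ngrams, ngram_len):
    """Illustrative block structure: True = may attend. The prefix is causal;
    each guess n-gram sees the full prefix and earlier tokens of itself,
    but never tokens of other n-grams."""
    total = prefix_len + num_ngrams * ngram_len
    mask = torch.zeros(total, total, dtype=torch.bool)
    mask[:prefix_len, :prefix_len] = torch.tril(
        torch.ones(prefix_len, prefix_len, dtype=torch.bool)
    )
    for g in range(num_ngrams):
        start = prefix_len + g * ngram_len
        block = slice(start, start + ngram_len)
        mask[block, :prefix_len] = True               # attend to the whole prefix
        mask[block, block] = torch.tril(              # causal within the n-gram only
            torch.ones(ngram_len, ngram_len, dtype=torch.bool)
        )
    return mask

# With a trie, tokens shared between n-grams would be deduplicated, so the
# independent per-n-gram blocks above would no longer suffice.
print(verification_mask(prefix_len=3, num_ngrams=2, ngram_len=2).int())
```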
From the blog post (under the Verification Branch heading): "In the verification branch, we identify n-grams whose first token matches the last input token. This is determined via a simple string match. Once identified, these n-grams are appended to the current input and subjected to verification via an LLM forward pass through them."
Yes, I'm referring to the CPU side. Using a trie should speed up this string match. Once a matching n-gram is found, the current attention map scheme can be used to speed up decoding.
I see what you mean now. If we want to verify all matching k-grams where $2 \le k \le n$, then I suppose the attention map would get more complicated. (Edit: Although it might not be worthwhile to validate all of these candidate k-grams due to the added FLOPs.)
I've added an initial implementation of the technique in llama.cpp, if anyone is interested in playing with it:
Are there any plans to integrate this into other open-source libraries, like vLLM or Huggingface Transformers?