hao-ai-lab / LookaheadDecoding

Apache License 2.0

Integration with other open-source libraries #1

Open shermansiu opened 8 months ago

shermansiu commented 8 months ago

Are there any plans to integrate this into other open-source libraries, like vLLM or Hugging Face Transformers?

jalajthanakicactus commented 8 months ago

Are there any plans to integrate with FastChat, vLLM, or Hugging Face TGI?

shermansiu commented 8 months ago

FastChat supports vLLM integration. TGI supports vLLM integration in versions ≥ 0.9 with supported models.

The current implementation of Lookahead Decoding is built on top of Hugging Face Transformers, though a CUDA kernel is coming.

The greatest benefit would obviously come from integrating this directly into Hugging Face Transformers, since this approach works without paged attention. However, the maintainers are already wary of overcomplicating the generate function.

jalajthanakicactus commented 8 months ago

A follow-up question: given that this approach can work without paged attention, can vLLM and Lookahead Decoding be used via FastChat for all the models that are part of FastChat and Chatbot Arena? Is this on the roadmap?

shermansiu commented 8 months ago

Like I said earlier, FastChat supports vLLM integration. vLLM should work for all of the models listed here.

I'm trying to get this integrated into Hugging Face Transformers.

For the sake of clarity, I am not one of the co-authors of Lookahead Decoding.

jalajthanakicactus commented 8 months ago

Thank you for the clarification @shermansiu

shermansiu commented 8 months ago

@Viol2000 Will the upcoming CUDA kernel be compatible with Flash Attention? Flash Attention has its own CUDA kernels and I wouldn't want to have to choose between one or the other.

Viol2000 commented 8 months ago

Hi @shermansiu, we will release a CUDA kernel based on FlashAttention very soon. Yes, it will be compatible with FlashAttention, so you will not need to choose between the two.

Viol2000 commented 8 months ago

Regarding integration, we currently support huggingface/transformers. You can speed up Transformers' native generate function with a few lines of code, as shown in our example code. We will make this clearer in the README. Note that we currently only support LLaMA models and greedy search. We will also integrate Lookahead Decoding with vLLM soon. Thanks for your interest!

shermansiu commented 8 months ago

The current implementation works by monkey-patching some of Llama's methods, which isn't ideal from a software development perspective. It would be better to apply the changes in lade/models/llama.py to the original file in Transformers itself.
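A toy sketch of why monkey-patching is fragile. `ToyAttention` and `lookahead_forward` are hypothetical stand-ins, not the real Transformers or `lade` code. Rebinding a method on the class changes behavior for every instance in the process, including instances created earlier by unrelated code, whereas a subclass keeps the change opt-in.

```python
class ToyAttention:
    """Hypothetical stand-in for a library's attention module."""

    def forward(self, x: str) -> str:
        return f"standard({x})"


def lookahead_forward(self, x: str) -> str:
    """Hypothetical replacement that a patching library might install."""
    return f"lookahead({x})"


# Monkey-patching: rebinding the class attribute affects every instance,
# including ones that already existed before the patch was applied.
existing = ToyAttention()
original_forward = ToyAttention.forward
ToyAttention.forward = lookahead_forward
patched = existing.forward("h")
ToyAttention.forward = original_forward  # undo the global patch


# Subclassing: only code that explicitly uses the subclass is affected.
class LookaheadAttention(ToyAttention):
    forward = lookahead_forward


plain = ToyAttention().forward("h")
opt_in = LookaheadAttention().forward("h")
print(patched, plain, opt_in)  # lookahead(h) standard(h) lookahead(h)
```

Upstreaming the changes into the original modeling code goes one step further than subclassing: no patching or extra class is needed at all.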

If I have time, I'm willing to contribute this directly to Hugging Face Transformers. This would let many more people use Lookahead Decoding and would require fewer package dependencies to get the best LLM inference performance. huggingface/transformers#27649

Viol2000 commented 8 months ago

Thank you for pointing out the drawback of our current implementation. We agree that applying the changes in lade/models/llama.py directly to the original file would be a more robust approach.

Your willingness to contribute directly to Hugging Face Transformers is greatly appreciated! It sounds like a valuable enhancement, enabling broader use of Lookahead Decoding and reducing the package dependencies needed for optimal LLM inference. We'll keep an eye on your progress on this. If you need any extra help, please reach out.

shermansiu commented 8 months ago

Thanks!

Currently, I'm trying to assess the best method to integrate this moving forward.

It might be best to integrate the CUDA kernel into https://github.com/Dao-AILab/flash-attention to ensure that Lookahead Decoding works with any new inference optimizations that Tri Dao comes up with. But then again, that would be subject to his discretion because maintaining the kernel would be extra work.

Also, n-gram caching might be out of scope for the flash-attn package, but would be in scope for 🤗 Transformers. Tom Aarsen is currently working on a KV cache refactor for Transformers, and Lookahead Decoding can be integrated once it is complete.

Viol2000 commented 8 months ago

Thank you for your coordination!

We're committed to keeping our CUDA kernel compatible with Tri Dao's updates in the flash-attention repo and will maintain it accordingly, whether or not it is integrated into flash-attention's official repo.

Also, we're keeping an eye on the KV cache refactor in transformers to integrate Lookahead Decoding post-refactor.

shermansiu commented 8 months ago

Also, as a minor implementation detail, it might be more efficient to store the n-grams as a trie. This should improve the performance of Lookahead Decoding somewhat.
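A minimal sketch of the suggested trie, assuming n-grams are tuples of integer token IDs. The `NGramTrie` class and its method names are hypothetical and are not `lade`'s actual data structure. Shared prefixes are stored once, and checking whether an n-gram is cached takes one dictionary lookup per token, i.e. O(N) for an n-gram of length N, regardless of how many n-grams are stored.

```python
from typing import Dict, List, Sequence, Tuple


class NGramTrie:
    """Hypothetical trie over token-ID n-grams (not lade's actual data structure)."""

    def __init__(self) -> None:
        self.children: Dict[int, "NGramTrie"] = {}
        self.is_end = False  # True if an inserted n-gram ends at this node

    def insert(self, ngram: Sequence[int]) -> None:
        node = self
        for tok in ngram:
            node = node.children.setdefault(tok, NGramTrie())
        node.is_end = True

    def contains(self, ngram: Sequence[int]) -> bool:
        # One dict lookup per token: O(N) for an n-gram of length N.
        node = self
        for tok in ngram:
            nxt = node.children.get(tok)
            if nxt is None:
                return False
            node = nxt
        return node.is_end

    def ngrams_starting_with(self, first: int) -> List[Tuple[int, ...]]:
        # Enumerate every stored n-gram whose first token is `first`.
        start = self.children.get(first)
        if start is None:
            return []
        out: List[Tuple[int, ...]] = []
        stack = [(start, (first,))]
        while stack:
            node, prefix = stack.pop()
            if node.is_end:
                out.append(prefix)
            for tok, child in node.children.items():
                stack.append((child, prefix + (tok,)))
        return out


trie = NGramTrie()
for g in [(5, 8, 2, 9), (5, 8, 3, 1), (7, 1, 1, 4)]:
    trie.insert(g)  # (5, 8, 2, 9) and (5, 8, 3, 1) share the prefix (5, 8)

print(trie.contains((5, 8, 3, 1)))           # True
print(trie.contains((5, 8, 3)))              # False (prefix only)
print(sorted(trie.ngrams_starting_with(5)))  # [(5, 8, 2, 9), (5, 8, 3, 1)]
```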

Viol2000 commented 8 months ago

Thanks for your constructive insight! Building a trie would reduce the number of input tokens but complicate building the attention mask. We are looking for a way to achieve better performance.

shermansiu commented 8 months ago

Please correct me if I'm misunderstanding something, but I don't think using a trie would affect the attention mask at all. The trie just helps identify matching n-grams more quickly during verification: instead of taking $O(G)$ time, it would take amortized $O(N)$ time to check whether the cache contains the n-gram we're searching for.

shermansiu commented 8 months ago

(Then again, with such small values of $G=W$ and $N$, any performance difference should be negligible in practice.)

Viol2000 commented 8 months ago

Hello, I believe we might be discussing different aspects of the verification process. There are two components: the GPU side computes the softmax outputs, while the CPU side verifies tokens based on these outputs. The time complexities $O(G)$ and $O(N)$ likely refer to the CPU side.
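To make the CPU-side step concrete, here is a minimal sketch of greedy verification, assuming the GPU forward pass has already been reduced to the argmax token at each verified position (greedy search being the only mode mentioned as supported in this thread). `accept_greedy` and `verify_candidates` are hypothetical names. For a candidate appended after the last input token, the prediction at the last input token must equal the candidate's first appended token, the prediction at that token must equal the next one, and so on. The matched prefix is accepted, plus the model's own prediction at the first mismatch, which greedy decoding would have produced anyway.

```python
from typing import List, Sequence, Tuple


def accept_greedy(candidate: Sequence[int], preds: Sequence[int]) -> List[int]:
    """Tokens accepted for one candidate under greedy verification.

    candidate: tokens appended after the last input token (the n-gram minus
               its first token, which matched the last input token).
    preds:     argmax token at each verified position; preds[0] is the
               prediction at the last input token and preds[i] the prediction
               at candidate[i - 1], so len(preds) == len(candidate) + 1.
    """
    if len(preds) != len(candidate) + 1:
        raise ValueError("expected one prediction per verified position")
    j = 0
    while j < len(candidate) and preds[j] == candidate[j]:
        j += 1
    # Matched prefix plus the model's own next token at the first mismatch.
    return list(candidate[:j]) + [preds[j]]


def verify_candidates(
    candidates: Sequence[Sequence[int]],
    preds_per_candidate: Sequence[Sequence[int]],
) -> Tuple[int, List[int]]:
    """Return (index, accepted tokens) of the longest accepted run.

    Returns (-1, []) if no candidates were verified.
    """
    best_idx = -1
    best: List[int] = []
    for i, (cand, preds) in enumerate(zip(candidates, preds_per_candidate)):
        accepted = accept_greedy(cand, preds)
        if len(accepted) > len(best):
            best_idx, best = i, accepted
    return best_idx, best


candidates = [[8, 2, 9], [8, 3, 1]]
preds = [[8, 3, 7, 0], [8, 3, 1, 6]]  # preds[0] is shared in a real batched pass
print(verify_candidates(candidates, preds))  # (1, [8, 3, 1, 6])
```

Even when no candidate matches, one token is still accepted, so a verification step never does worse than ordinary greedy decoding in tokens per step.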

Using a trie would require redesigning the GPU-side verification branch, illustrated in the bottom-right part of Figure 5 of the blog. The current system processes discrete n-grams, but with a trie, these n-grams would be interdependent, which changes the attention mask. This approach could significantly reduce costs, especially with a large number of guess n-grams.
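A minimal sketch of the attention-mask difference described here, assuming a tree-attention layout. `discrete_mask` and `trie_mask` are hypothetical helpers, and the real kernel's layout may differ. The input prefix, which every appended token can see, and the lookahead branch are both omitted from the matrices. With discrete n-grams, each appended token attends causally within its own n-gram only. With a trie, shared prefixes are stored once and each token attends to itself and its ancestors.

```python
from typing import Dict, List, Sequence, Tuple


def discrete_mask(cands: Sequence[Sequence[int]]) -> Tuple[List[int], List[List[int]]]:
    """Lay candidates out back to back; causal attention within each one."""
    tokens: List[int] = []
    owner: List[Tuple[int, int]] = []  # (candidate index, position in candidate)
    for i, cand in enumerate(cands):
        for k, tok in enumerate(cand):
            tokens.append(tok)
            owner.append((i, k))
    n = len(tokens)
    mask = [[0] * n for _ in range(n)]
    for q in range(n):
        qi, qk = owner[q]
        for kv in range(n):
            ki, kk = owner[kv]
            mask[q][kv] = int(qi == ki and kk <= qk)
    return tokens, mask


def trie_mask(cands: Sequence[Sequence[int]]) -> Tuple[List[int], List[List[int]]]:
    """Merge shared prefixes into one node; attend to self and ancestors."""
    tokens: List[int] = []
    parent: List[int] = []  # parent node index, -1 means the input prefix
    index: Dict[Tuple[int, ...], int] = {}  # path from root -> node index
    for cand in cands:
        for k in range(len(cand)):
            path = tuple(cand[: k + 1])
            if path not in index:
                index[path] = len(tokens)
                tokens.append(cand[k])
                parent.append(index[path[:-1]] if k > 0 else -1)
    n = len(tokens)
    mask = [[0] * n for _ in range(n)]
    for q in range(n):
        node = q
        while node != -1:
            mask[q][node] = 1
            node = parent[node]
    return tokens, mask


cands = [[8, 2, 9], [8, 3, 1]]
d_tokens, d_mask = discrete_mask(cands)
t_tokens, t_mask = trie_mask(cands)
print(len(d_tokens), len(t_tokens))  # 6 5
for row in t_mask:
    print(row)
```

Here the trie saves one input token because both candidates start with 8, but the mask is no longer block-diagonal: tokens 3 and 1 must attend back to the shared node for 8, which is the extra complexity in building the mask.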

shermansiu commented 8 months ago

From the blog post (under the Verification Branch heading): "In the verification branch, we identify n-grams whose first token matches the last input token. This is determined via a simple string match. Once identified, these n-grams are appended to the current input and subjected to verification via an LLM forward pass through them."
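The "simple string match" in the quoted passage can be sketched as a linear scan over the n-gram pool, assuming the pool is a flat list of token-ID tuples. `select_candidates` and `build_verification_input` are hypothetical. `lade` may organize its pool differently, and `max_candidates` is only assumed to play the role of the guess-set size G. The scan touches every stored n-gram, which is the cost the trie suggestion aims to reduce.

```python
from typing import List, Sequence, Tuple


def select_candidates(
    pool: Sequence[Tuple[int, ...]], last_token: int, max_candidates: int
) -> List[Tuple[int, ...]]:
    """Return up to max_candidates continuations of n-grams starting with last_token."""
    matches: List[Tuple[int, ...]] = []
    for ngram in pool:  # linear scan over every stored n-gram
        if ngram[0] == last_token:
            # The first token duplicates the last input token, so only the
            # remaining N - 1 tokens are appended for verification.
            matches.append(ngram[1:])
            if len(matches) == max_candidates:
                break
    return matches


def build_verification_input(
    input_ids: Sequence[int], candidates: Sequence[Tuple[int, ...]]
) -> List[int]:
    """Append all candidates after the input (discrete layout; the attention
    mask, not shown here, keeps candidates from seeing each other)."""
    out = list(input_ids)
    for cand in candidates:
        out.extend(cand)
    return out


pool = [(5, 8, 2, 9), (7, 1, 1, 4), (5, 8, 3, 1)]
input_ids = [11, 42, 5]
cands = select_candidates(pool, last_token=input_ids[-1], max_candidates=7)
print(cands)  # [(8, 2, 9), (8, 3, 1)]
print(build_verification_input(input_ids, cands))  # [11, 42, 5, 8, 2, 9, 8, 3, 1]
```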

Yes, I'm referring to the CPU side. Using a trie should speed up this string match. Once a matching n-gram is found, the current attention map scheme can be used to speed up decoding.

I see what you mean now. If we want to verify all matching k-grams where $2\le k \le n$, then I suppose the attention map would get more complicated. (Edit: That said, it might not be worthwhile to verify all of these candidate k-grams due to the added FLOPs.)

ggerganov commented 8 months ago

I've added an initial implementation of the technique in llama.cpp if anyone is interested in playing with it:

https://github.com/ggerganov/llama.cpp/pull/4207