flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

QUESTION: How to implement tree attention with flashinfer #152

Open UranusSeven opened 4 months ago

UranusSeven commented 4 months ago

Hi, thanks for your awesome work!

I'm trying to implement https://github.com/SafeAILab/EAGLE with high-performance kernels. I read this blog post, and it says:

FlashInfer implements prefill/append kernels for Paged KV-Cache which none of the existing libraries have done before, and it can be used to serve models in speculative decoding setting.

However, I was unable to locate arguments like position_id (utilized for rotary embedding) and attention_mask (for enforcing causality constraints).

Could you please provide an example of implementing tree attention using flashinfer? Any guidance you can offer would be greatly appreciated.

zhyncs commented 4 months ago

To support the non-contiguous token positions introduced by speculative decoding, two adjustments are needed: one to the cos/sin matrix of RoPE, and the other replacing the causal mask with a tree mask. With these, it becomes very convenient to implement algorithms such as Medusa and EAGLE. Judging from the documentation at https://docs.flashinfer.ai/index.html, this is not supported yet.
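For illustration, a minimal sketch (plain PyTorch, not a FlashInfer API) of such a tree mask: entry (i, j) is 0 where draft token i may attend (the committed prefix, itself, and its ancestors in the draft tree) and -inf where it is masked out. The parent-index layout and the -inf convention are assumptions made only for this sketch.

```python
import torch

def build_tree_mask(parents: list[int], prefix_len: int) -> torch.Tensor:
    """Additive attention mask for a draft tree appended to a committed prefix.

    parents[i] is the in-tree index of draft token i's parent, or -1 if its
    parent is the last committed token. Returns a (num_draft, prefix_len +
    num_draft) tensor with 0 where attention is allowed and -inf elsewhere.
    """
    n = len(parents)
    mask = torch.full((n, prefix_len + n), float("-inf"))
    mask[:, :prefix_len] = 0.0                  # every draft token sees the whole prefix
    for i, p in enumerate(parents):
        mask[i, prefix_len + i] = 0.0           # ... and itself
        while p != -1:                          # ... and all of its ancestors in the tree
            mask[i, prefix_len + p] = 0.0
            p = parents[p]
    return mask

# Example: prefix of 4 tokens, draft tree 0 -> {1, 2}, 1 -> 3
mask = build_tree_mask([-1, 0, 0, 1], prefix_len=4)
```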

jpf888 commented 4 months ago

We also need Medusa support when we use MLC-LLM again, and we have seen that TensorRT-LLM supports Medusa.

zhyncs commented 3 months ago

We also need Medusa support when we use MLC-LLM again, and we have seen that TensorRT-LLM supports Medusa.

The current implementation of Medusa in TensorRT-LLM is not fully functional, nor is it a SOTA implementation. By the way, if Medusa is not implemented with a tree mask, you can simply add a verification module at the model output without modifying any kernel code in the project. However, performance will be slightly worse and there will be redundant verification.
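To make that verification module concrete, here is a minimal sketch of greedy verification without a tree mask: the target model scores every draft position (the redundant computation mentioned above), and the longest agreeing prefix is accepted. The function and tensor names are illustrative only, not part of FlashInfer.

```python
import torch

def greedy_verify(draft_tokens: torch.Tensor, target_logits: torch.Tensor) -> int:
    """Accept the longest prefix of draft_tokens that the target model agrees with.

    draft_tokens:  (num_draft,) token ids proposed by the draft model.
    target_logits: (num_draft, vocab) target-model logits predicting each draft
                   slot (i.e. the output at the position just before that token).
    Returns the number of accepted draft tokens.
    """
    target_choice = target_logits.argmax(dim=-1)   # greedy target prediction per slot
    agree = target_choice == draft_tokens
    if agree.all():
        return draft_tokens.numel()
    # Accept everything up to (but not including) the first disagreement.
    return int((~agree).nonzero()[0].item())
```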

UranusSeven commented 3 months ago

To support the non-contiguous token positions introduced by speculative decoding, two adjustments are needed: one to the cos/sin matrix of RoPE, and the other replacing the causal mask with a tree mask. With these, it becomes very convenient to implement algorithms such as Medusa and EAGLE. Judging from the documentation at https://docs.flashinfer.ai/index.html, this is not supported yet.

Agreed. But by using BatchPrefillWithPagedKVCacheWrapper, we can largely sidestep the attention mask issue by simply turning one draft sequence into a batch, as sketched below.
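A rough illustration of that workaround (a sketch of the idea, not an exact FlashInfer call sequence): enumerate every root-to-leaf path of the draft tree and treat each path, appended to the committed prefix, as its own causal sequence in the batch. A batch prefill with causal=True then reproduces the visibility the tree mask would enforce, at the cost of recomputing tokens shared between paths.

```python
def tree_to_paths(parents: list[int]) -> list[list[int]]:
    """Expand a draft tree into root-to-leaf paths.

    parents[i] is the in-tree parent of draft token i, or -1 for a root.
    Each returned path lists draft-token indices from root to leaf; feeding
    prefix + path as one causal batch entry matches what the tree mask
    would allow, with tokens shared between paths computed more than once.
    """
    children = {i: [] for i in range(len(parents))}
    for i, p in enumerate(parents):
        if p != -1:
            children[p].append(i)
    leaves = [i for i in range(len(parents)) if not children[i]]
    paths = []
    for leaf in leaves:
        path, node = [], leaf
        while node != -1:
            path.append(node)
            node = parents[node]
        paths.append(path[::-1])
    return paths

# The draft tree 0 -> {1, 2}, 1 -> 3 becomes the paths [0, 1, 3] and [0, 2].
print(tree_to_paths([-1, 0, 0, 1]))
```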

yzh119 commented 1 month ago

@zhyncs @UranusSeven we added support for custom attention masks in #266; more documentation is coming.

zhyncs commented 1 month ago

@zhyncs @UranusSeven we added support for custom attention masks in #266; more documentation is coming.

Cheers!

chenzhuofu commented 4 weeks ago

Hi @yzh119, thanks for your great contribution on this issue! I'd like to adopt flashinfer (with a custom causal mask) in my current project. However, I have a small question: which values should I set in the custom_mask? I guess I should set -5e4 for masked positions and 0 for the others. Am I right? :)

UranusSeven commented 4 weeks ago

@zhyncs @UranusSeven we added support for custom attention masks in #266; more documentation is coming.

Thanks for your amazing work!

yzh119 commented 4 weeks ago

@chenzhuofu @UranusSeven Hi, thanks for your attention. Yes, setting -inf or -5e4 for masked positions and 0 for the others is correct.

Some simple examples: https://docs.flashinfer.ai/generated/flashinfer.prefill.single_prefill_with_kv_cache.html#flashinfer.prefill.single_prefill_with_kv_cache

Or check this test case for batch attention: https://github.com/flashinfer-ai/flashinfer/blob/7aadc0d54a5bc01ccfc75baf4d24d8a821a42b31/python/tests/test_batch_prefill_kernels.py#L317-L369. There, using a triu-style attention mask (-inf for masked positions, 0 for the others) is equivalent to setting causal=True.
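As a small illustration of that convention (shapes and the way the mask is passed are assumptions; the exact custom-mask argument may differ between FlashInfer versions), this builds the additive upper-triangular mask that should behave like causal=True:

```python
import torch

qo_len, kv_len = 8, 8
# 0 where attention is allowed, -inf (or -5e4) where it is masked out.
causal_like_mask = torch.triu(
    torch.full((qo_len, kv_len), float("-inf")), diagonal=1
)
# Passing this as the custom mask should match causal=True when
# qo_len == kv_len; when kv_len > qo_len, shift the diagonal by
# kv_len - qo_len so queries also attend to the earlier KV prefix.
# e.g. (argument name assumed from the linked docs):
# o = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=causal_like_mask)
```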

Tomorrowdawn commented 1 week ago

Hi, flashinfer now supports custom masks, which is great work! But what about the positional embedding? I found #69 introducing q_position and kv_position in the C++ kernels, but I didn't find a corresponding Python API (am I missing something?).