flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

[WIP] refactor: make `gqa_group_size` a function argument instead of template parameter #262

Closed yzh119 closed 3 months ago

yzh119 commented 4 months ago

Our previous implementation treats gqa_group_size as a template parameter. It also implicitly assumes that 8 is divisible by gqa_group_size. This assumption is not documented and has caused a lot of confusion for developers.
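To make the implicit assumption concrete, here is a minimal sketch of which group sizes pass the "8 is divisible by gqa_group_size" check. The helper names `compute_gqa_group_size` and `satisfies_old_constraint` are hypothetical and not part of FlashInfer. The sketch also assumes the usual definition of the group size as the number of query heads divided by the number of KV heads.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not a FlashInfer API).
// Assumption: the group size is num_qo_heads / num_kv_heads,
// with num_qo_heads being a multiple of num_kv_heads.
constexpr uint32_t compute_gqa_group_size(uint32_t num_qo_heads,
                                          uint32_t num_kv_heads) {
  return num_qo_heads / num_kv_heads;
}

// Hypothetical helper (not a FlashInfer API).
// Returns true only when the group size evenly divides 8,
// i.e. for the values 1, 2, 4 and 8.
constexpr bool satisfies_old_constraint(uint32_t gqa_group_size) {
  return gqa_group_size != 0 && 8 % gqa_group_size == 0;
}
```

Under these assumptions, a model with 32 query heads and 8 KV heads (group size 4) passes the check, while one with 28 query heads and 4 KV heads (group size 7) does not.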

#223 could be a workaround, but it would quickly increase the binary size of this library if we keep adding more and more gqa group sizes.

This PR refactors the code so that gqa_group_size becomes a runtime function argument, and we no longer need to compile one kernel for each value. This would bring a very slight performance degradation for gqa_group_size=1 but should work well in general. This PR also removes the requirement that "8 should be divisible by gqa_group_size", so that any gqa group size should be supported.
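The difference between the two designs can be sketched in plain C++. This is an illustration only, not the actual FlashInfer kernels: all function names are hypothetical, and the query-head to KV-head mapping (`qo_head_idx / gqa_group_size`) is an assumption chosen to keep the example small.

```cpp
#include <cassert>
#include <cstdint>

// Template version: the group size is fixed at compile time, so each
// supported value produces its own instantiation in the binary.
template <uint32_t GROUP_SIZE>
uint32_t kv_head_static(uint32_t qo_head_idx) {
  return qo_head_idx / GROUP_SIZE;
}

// The template version needs a dispatch over a fixed list of sizes.
// Anything outside the list is unsupported; UINT32_MAX marks that here.
inline uint32_t kv_head_dispatch(uint32_t qo_head_idx,
                                 uint32_t gqa_group_size) {
  switch (gqa_group_size) {
    case 1: return kv_head_static<1>(qo_head_idx);
    case 2: return kv_head_static<2>(qo_head_idx);
    case 4: return kv_head_static<4>(qo_head_idx);
    case 8: return kv_head_static<8>(qo_head_idx);
    default: return UINT32_MAX;
  }
}

// Runtime version: one function handles every positive group size.
inline uint32_t kv_head_runtime(uint32_t qo_head_idx,
                                uint32_t gqa_group_size) {
  assert(gqa_group_size > 0);
  return qo_head_idx / gqa_group_size;
}
```

In this sketch, `kv_head_dispatch(13, 7)` hits the unsupported branch, while `kv_head_runtime(13, 7)` returns 1. The trade-off is that the compiler can no longer specialize on a known constant, which is one plausible source of the small slowdown mentioned above.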

This PR should resolve #142 #181 #246 #254 #258 and so on.

Still a work in progress. Estimated finish time: June 3rd (revised from May 27th at noon PST, then May 30th).

yzh119 commented 3 months ago

Follow up in #301