flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0
768 stars 64 forks

feat: Separate Q and KV dtypes for decode #286

Closed. Yard1 closed this pull request 3 weeks ago.

Yard1 commented 1 month ago

Closes https://github.com/flashinfer-ai/flashinfer/issues/285

Modified unit tests pass. May need some extra validation.

Yard1 commented 1 month ago

@yzh119 Please let me know if this is on the right track! I couldn't see anything directly related to the dtype of the query in the kernels, so my assumption is this should "just work", but I'm not sure whether it will affect e.g. q_vec loading. I'm compiling it now to test.

yzh119 commented 1 month ago

Yes I do think you are on the right track, thank you!

but I'm not sure whether it will affect e.g. q_vec loading.

I don't think so.

Yard1 commented 3 weeks ago

@yzh119 The modified unit test passes for me, can you review and validate?

Yard1 commented 3 weeks ago

@yzh119 Correct, I wanted to avoid having to modify the public API. I don't think the query dtype is used in resource estimation, but please correct me if that's not the case. Happy to make the change then.

yzh119 commented 3 weeks ago

Hi @Yard1, I'm a little conservative here because this section of code

https://github.com/flashinfer-ai/flashinfer/blob/1250b686869b514f796029f5d05fab114ec7d540/include/flashinfer/attention/handler.cuh#L121-L130

might produce a different num_blocks_per_sm because of the difference in query dtype in the kernel.

Yard1 commented 3 weeks ago

Ok sounds good! Let me make the change.

Yard1 commented 3 weeks ago

@yzh119 Updated, PTAL!