cyLi-Tiger closed this issue 3 months ago
I recompiled with set(CMAKE_CUDA_ARCHITECTURES 80) and it's working now, which is quite weird lmao.
Hi @cyLi-Tiger ,
I just tried to reproduce the problem (see the attached figure). I suspect it comes from runtime APIs such as cudaOccupancyMaxActiveBlocksPerMultiprocessor behaving differently across compute capabilities. As a result, the pre-kernel setup (like this) may fail and trigger the assertion. Sorry for the trouble; Quest has only been tested on sm89 GPUs :)
Thanks for your prompt reply, @happierpig!
Another question: I'm new to FlashInfer and wonder why the current kernel doesn't support GQA. Is it because FlashInfer itself doesn't support GQA, or because GQA isn't suitable for Quest currently?
Besides, can a similar approach be applied to the prefill stage to reduce TTFT?
Thanks for your great questions!
FlashInfer does efficiently support various GQA settings (even MLA). However, Quest currently does not. To use tensor cores for efficient GQA execution, query heads within the same group should attend to the same set of KV tokens. We are conducting preliminary experiments to add GQA support.
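To make the "same set of KV tokens per group" constraint concrete, here is a minimal Python sketch. It is not Quest's or FlashInfer's implementation: the function name `shared_topk_pages`, the per-head page scores, and the choice of `max` as the aggregation across heads are all assumptions made for illustration.

```python
from typing import List


def shared_topk_pages(scores: List[List[float]], group_size: int, k: int) -> List[List[int]]:
    """Pick one shared set of KV pages per GQA group.

    scores[h][p] is a hypothetical estimated importance of KV page p
    for query head h. Returns, for each group (one KV head), the sorted
    indices of the k pages that every query head in that group attends to.
    """
    num_q_heads = len(scores)
    assert num_q_heads % group_size == 0, "query heads must split evenly into groups"

    selected = []
    for g in range(num_q_heads // group_size):
        group = scores[g * group_size:(g + 1) * group_size]
        # Aggregate scores across the heads of the group.
        # Using max is an assumption; sum or mean would also give a shared ranking.
        aggregated = [max(column) for column in zip(*group)]
        top = sorted(range(len(aggregated)), key=lambda p: aggregated[p], reverse=True)[:k]
        selected.append(sorted(top))
    return selected


if __name__ == "__main__":
    # 4 query heads, group size 2 (so 2 KV heads), 4 KV pages, keep 2 pages per group.
    example_scores = [
        [0.1, 0.9, 0.2, 0.3],
        [0.8, 0.1, 0.2, 0.3],
        [0.5, 0.4, 0.3, 0.2],
        [0.1, 0.2, 0.9, 0.3],
    ]
    print(shared_topk_pages(example_scores, group_size=2, k=2))
```

With a shared selection like this, the query heads of one group can be stacked along the M dimension of a single matrix multiply against the same KV pages, which is the property the reply above is asking for.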
Regarding the prefill phase, the situation is similar to GQA. The basic tensor-op shape is 16x8x16 (M dimension 16), so it is better to align 2 consecutive query tokens to attend to the same "critical" tokens, so that 2 x Group_size can saturate 16 (the M dimension). We can therefore use some aggregate op in the estimation process to implement this.
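The arithmetic behind "2 x Group_size can saturate 16" can be sketched as follows. This is only an illustration: the helper names are hypothetical, and the group size of 8 is inferred from the statement that two tokens fill an M dimension of 16.

```python
import math

MMA_M = 16  # M dimension of the 16x8x16 tensor-op shape mentioned above


def tokens_to_fill_tile(group_size: int, mma_m: int = MMA_M) -> int:
    """Number of consecutive query tokens that must share the same KV tokens
    so that (tokens x group_size) rows cover at least one full M tile."""
    return math.ceil(mma_m / group_size)


def tile_utilization(group_size: int, num_tokens: int, mma_m: int = MMA_M) -> float:
    """Fraction of M-dimension rows doing useful work when num_tokens query
    tokens of one GQA group are stacked into M tiles of height mma_m."""
    rows = group_size * num_tokens
    tiles = math.ceil(rows / mma_m)
    return rows / (tiles * mma_m)


if __name__ == "__main__":
    # Assumed group size of 8: one token fills half the tile, two tokens fill it.
    print(tokens_to_fill_tile(8))   # 2
    print(tile_utilization(8, 1))   # 0.5
    print(tile_utilization(8, 2))   # 1.0
```

Under these assumptions, a single query token with group size 8 leaves half of each tile idle, which is why aligning two consecutive tokens on the same critical tokens is attractive.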
Since basic shape of tensor op is 16x8x16 (M dimension 16)
Might seem silly... but where does this come from? Any reference?
You can check the official CUDA docs here. Since "mma.sync.m8n8k4 is optimized for target architecture sm_70 and may have substantially reduced performance on other target architectures", I suppose the effective minimum M dimension is 16.
All my questions are well answered, thanks!
Hi, thanks for the great work!
I got an error while benchmarking batch_decode's efficiency with the following command. The logs are attached. I didn't change the code; any clues on solving that?