jon-chuang opened this issue 3 months ago
Actually, @comaniac, I noticed that there are explicit asserts forbidding the use of FlashInfer kernels for chunked prefill: https://github.com/vllm-project/vllm/blob/774cd1d3bf7890c6abae6c7ace798c4a376b2b20/vllm/attention/backends/flashinfer.py#L195
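For context, the guard is roughly of this shape (paraphrased from memory; the exact attribute names and message at that line may differ):

```python
# Paraphrased sketch of the guard in vllm/attention/backends/flashinfer.py;
# the attribute names and exact condition may differ in the linked revision.
if self.num_prefills > 0:
    assert self.num_decode_tokens == 0, (
        "Chunked prefill is not supported with the FlashInfer backend: "
        "a batch must be all-prefill or all-decode."
    )
```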
As pointed out in: https://github.com/flashinfer-ai/flashinfer/issues/392#issuecomment-2246997216
My understanding is that this is because vLLM runs prefill and decode in two separate kernel invocations by default (as is the case for flash-attention, see https://github.com/vllm-project/vllm/pull/6052), and that this applies to FlashInfer as well?
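Roughly speaking (hypothetical names below, not vLLM's actual internals), the default path splits a mixed scheduler step into two separate launches, so each backend's metadata only ever describes a pure-prefill or pure-decode batch:

```python
import torch

def run_attention_split(prefill_q, decode_q, kv_cache, prefill_kernel, decode_kernel):
    """Hypothetical sketch of the default two-invocation path.

    `prefill_kernel` / `decode_kernel` stand in for the backend's prefill
    and decode attention entry points; at least one input is non-empty.
    """
    outputs = []
    if prefill_q.numel() > 0:
        # Variable-length prompt tokens: one kernel launch for all prefills.
        outputs.append(prefill_kernel(prefill_q, kv_cache))
    if decode_q.numel() > 0:
        # One query token per running sequence: a second, separate launch.
        outputs.append(decode_kernel(decode_q, kv_cache))
    return torch.cat(outputs, dim=0)
```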
Perhaps the first step is to unify the FlashInfer path into a single kernel invocation, similar to https://github.com/vllm-project/vllm/pull/6052, or at least to clarify in which scenarios it is OK to run FlashInfer kernels for chunked prefill, since according to @yzh119 in https://github.com/flashinfer-ai/flashinfer/issues/392 FlashInfer should already support this. See the sketch below.
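For reference, a minimal sketch of what a single-invocation path could look like, treating decode tokens as length-1 causal prefills so one `BatchPrefillWithPagedKVCacheWrapper` call covers the whole mixed batch. Argument names and order follow my recollection of the FlashInfer v0.1.x API and may not match exactly; the shapes and page indices are made up for illustration:

```python
import torch
import flashinfer

# Illustrative assumptions: 32 query heads, 4 KV heads, head_dim 128, page size 16,
# one batch mixing a chunked prefill (qo_len 512) and a decode (qo_len 1).
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 4, 128, 16

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

# qo_indptr encodes per-request query lengths: [512, 1] -> [0, 512, 513].
qo_indptr = torch.tensor([0, 512, 513], dtype=torch.int32, device="cuda")
# Paged KV layout for the two requests (page counts / last-page lengths are made up).
paged_kv_indptr = torch.tensor([0, 40, 72], dtype=torch.int32, device="cuda")
paged_kv_indices = torch.arange(72, dtype=torch.int32, device="cuda")
paged_kv_last_page_len = torch.tensor([16, 9], dtype=torch.int32, device="cuda")

wrapper.begin_forward(
    qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
)

q = torch.randn(513, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
# Paged KV cache in "NHD" layout: (num_pages, 2, page_size, num_kv_heads, head_dim).
kv_cache = torch.randn(72, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")

out = wrapper.forward(q, kv_cache, causal=True)  # (513, num_qo_heads, head_dim)
wrapper.end_forward()
```

Since a length-1 causal query attends to its full paged context, decode falls out as a special case of prefill here, which is presumably why @yzh119 says this is already supported on the FlashInfer side.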
Anyway, please assign it to me; I will investigate further.
We are already working on this, cc @Yard1.
@comaniac Any updates or open PRs on this that we can take a look at?
🚀 The feature, motivation and pitch
From the new FlashInfer release: https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.1.4
cc @comaniac
Additional context
Follow-up to: https://github.com/vllm-project/vllm/pull/7208, https://github.com/vllm-project/vllm/pull/7185