Currently, in vLLM we do the qkv projection as one matmul, producing a fused tensor that is then split into q, k and v. The split leaves the resulting tensors non-contiguous. It would be great if FlashInfer's attention kernels could support this case (and avoid the copies required to make the tensors contiguous) by accepting a stride parameter for each tensor, similarly to how Flash Attention does. This concerns both the Paged (just query) and Ragged (query, key, value) kernels.
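For concreteness, here is a minimal PyTorch sketch (with made-up sizes, not vLLM's actual config) showing how the fused projection plus split produces non-contiguous views that currently have to be copied before the kernel call:

```python
import torch

# Hypothetical sizes for illustration only.
num_tokens, num_heads, head_dim = 8, 4, 64
hidden_size = num_heads * head_dim

# Fused QKV projection: one matmul producing [num_tokens, 3 * hidden_size].
hidden_states = torch.randn(num_tokens, hidden_size)
qkv_weight = torch.randn(3 * hidden_size, hidden_size)
qkv = hidden_states @ qkv_weight.t()

# Splitting along the last dim yields views into the fused buffer,
# so their row stride is 3 * hidden_size instead of hidden_size.
q, k, v = qkv.split(hidden_size, dim=-1)
print(q.is_contiguous(), q.stride())  # False, (3 * hidden_size, 1)

# Today an extra copy is needed to make each tensor contiguous;
# a per-tensor stride parameter would let the kernel consume the views directly.
q_contig = q.contiguous()
```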