Amanda-Barbara opened 1 week ago
It seems the accuracy error stems from `flashinfer.single_prefill_with_kv_cache` not supporting `cu_seqlens_q` and `cu_seqlens_k`. If I want to use flashinfer's prefill function, how do I call it the way `flash_attn_2_cuda.varlen_fwd` is called?
`single_prefill_with_kv_cache` is only designed for a single request (no batching and no variable length).
For batch prefill with variable length, you have to use https://docs.flashinfer.ai/api/python/prefill.html#flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper, where the query and key/value cache are organized as ragged tensors (see our layout documentation: https://docs.flashinfer.ai/tutorials/kv_layout.html).
If you use a paged KV cache, you should use https://docs.flashinfer.ai/api/python/prefill.html#flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper instead.
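For what it's worth, the ragged layout those wrappers consume uses the same indptr (a.k.a. `cu_seqlens`) convention as `varlen_fwd`: all requests are packed along the token dimension, and an exclusive prefix sum of sequence lengths marks the boundaries. A minimal NumPy sketch of that packing (array names and shapes here are illustrative, not the flashinfer API):

```python
import numpy as np

# Hypothetical per-request sequence lengths for a batch of 3 prefill requests.
seq_lens = [2, 5, 3]
num_heads, head_dim = 4, 8

# The indptr (cu_seqlens) array is the exclusive prefix sum of the lengths:
# request i occupies rows indptr[i]:indptr[i+1] of the packed tensor.
indptr = np.zeros(len(seq_lens) + 1, dtype=np.int32)
indptr[1:] = np.cumsum(seq_lens)

# Pack all requests into one ragged tensor of shape [total_tokens, heads, dim].
reqs = [np.random.rand(n, num_heads, head_dim).astype(np.float32) for n in seq_lens]
packed = np.concatenate(reqs, axis=0)

# Any request can be recovered by slicing at the indptr boundaries.
for i, r in enumerate(reqs):
    assert np.array_equal(packed[indptr[i]:indptr[i + 1]], r)

print(indptr.tolist())  # → [0, 2, 7, 10]
```

The same `indptr`-style arrays are what the batch prefill wrappers take for the query and key/value sides, so an existing `cu_seqlens` tensor maps over directly.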
@yzh119 thanks very much, I will try it.
There are precision errors compared with `flash_attn_2_cuda.varlen_fwd` when I use the `flashinfer.single_prefill_with_kv_cache` function to run the cohere_plus model. Below is the code I used:

```python
fi_fwd_out = flashinfer.single_prefill_with_kv_cache(
    q.contiguous(), k.contiguous(), v.contiguous(),
    causal=True, sm_scale=softmax_scale, allow_fp16_qk_reduction=False,
)
fa2_fwd_out = flash_attn_2_cuda.varlen_fwd(
    q, k, v, out,
    cu_seqlens, cu_seqlens, max_s, max_s,
    0.0, softmax_scale, False, True, False, None,
)
torch.allclose(fi_fwd_out, fa2_fwd_out, rtol=1e-3, atol=1e-3)
```

It is worth noting that the outputs for the first half of the layers are the same, but those for the second half differ. Can you give an official example of a precision comparison with `flash_attn_2_cuda.varlen_fwd`? Thanks!
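Not an official example, but one way to localize which kernel diverges is to compare both against a high-precision reference on the same inputs, rather than only against each other. A framework-agnostic NumPy sketch of naive causal attention in float64 (function name and tolerances are my own choices, not from either library):

```python
import numpy as np

def ref_causal_attention(q, k, v, sm_scale):
    """Naive causal attention for a single sequence, computed in float64.

    q, k, v: [seq_len, num_heads, head_dim], the single-request layout
    both kernels consume for one sequence.
    """
    s, h, d = q.shape
    out = np.empty((s, h, d), dtype=np.float64)
    # Causal mask: position i may only attend to positions <= i.
    mask = np.triu(np.ones((s, s), dtype=bool), k=1)
    for head in range(h):
        scores = (q[:, head].astype(np.float64)
                  @ k[:, head].astype(np.float64).T) * sm_scale
        scores[mask] = -np.inf
        # Numerically stable softmax over the key dimension.
        scores -= scores.max(axis=-1, keepdims=True)
        p = np.exp(scores)
        p /= p.sum(axis=-1, keepdims=True)
        out[:, head] = p @ v[:, head].astype(np.float64)
    return out

# Illustrative check on random data; with real kernels you would compare
# each fp16 output against this fp64 reference with fp16-level tolerances.
rng = np.random.default_rng(0)
s, h, d = 16, 2, 8
q = rng.standard_normal((s, h, d)).astype(np.float32)
k = rng.standard_normal((s, h, d)).astype(np.float32)
v = rng.standard_normal((s, h, d)).astype(np.float32)
ref = ref_causal_attention(q, k, v, sm_scale=d ** -0.5)
```

If both kernels agree with the reference up to fp16 rounding but disagree with each other at some layer, the difference is likely accumulation order rather than a correctness bug; checking per-layer against the reference narrows down where the drift starts.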