flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

There are precision errors compared with flash_attn_2_cuda.varlen_fwd #335

Open Amanda-Barbara opened 1 week ago

Amanda-Barbara commented 1 week ago

There are precision errors compared with flash_attn_2_cuda.varlen_fwd when I use the flashinfer.single_prefill_with_kv_cache function to run the cohere_plus model. Below is the code I used:

```python
fi_fwd_out = flashinfer.single_prefill_with_kv_cache(
    q.contiguous(), k.contiguous(), v.contiguous(),
    causal=True, sm_scale=softmax_scale, allow_fp16_qk_reduction=False,
)
fa2_fwd_out = flash_attn_2_cuda.varlen_fwd(
    q, k, v, out,
    cu_seqlens, cu_seqlens, max_s, max_s,
    0.0, softmax_scale, False, True, False, None,
)
torch.allclose(fi_fwd_out, fa2_fwd_out, rtol=1e-3, atol=1e-3)
```

It is worth noting that the outputs of the first half of the layers are the same, but those of the second half differ. Can you give an official example of a precision comparison against flash_attn_2_cuda.varlen_fwd? Thanks!

Amanda-Barbara commented 1 week ago

It seems the cause of the accuracy error is that flashinfer.single_prefill_with_kv_cache doesn't support cu_seqlens_q and cu_seqlens_k. If I want to use flashinfer's prefill function, how can I call it the same way as flash_attn_2_cuda.varlen_fwd?

yzh119 commented 1 week ago

single_prefill_with_kv_cache is designed for a single request only (no batching, no variable sequence lengths).

For batch prefill with variable-length sequences, you have to use https://docs.flashinfer.ai/api/python/prefill.html#flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper, where the queries and the key/value cache are organized as ragged tensors (see our layout documentation: https://docs.flashinfer.ai/tutorials/kv_layout.html).

If you use a paged KV cache, you should use https://docs.flashinfer.ai/api/python/prefill.html#flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper instead.
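
For reference, here is a minimal sketch (not an official example) of how the ragged-KV-cache wrapper can be driven with the same cu_seqlens-style offsets that varlen_fwd takes. It assumes q, k, v are packed as [total_tokens, num_heads, head_dim] tensors and that cu_seqlens, num_qo_heads, num_kv_heads, head_dim, softmax_scale, and fa2_fwd_out come from the setup in the first comment; the exact method names (begin_forward/forward/end_forward vs. plan/run) depend on the installed flashinfer version, so please check the docs linked above.

```python
import torch
import flashinfer

# Workspace buffer used by the wrapper's scheduling/planning stage.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, kv_layout="NHD")

# cu_seqlens plays the role of both qo_indptr and kv_indptr here,
# mirroring the varlen_fwd call where query and key lengths are identical.
wrapper.begin_forward(
    cu_seqlens,   # qo_indptr: [batch_size + 1] prefix sums of query lengths
    cu_seqlens,   # kv_indptr: [batch_size + 1] prefix sums of key/value lengths
    num_qo_heads,
    num_kv_heads,
    head_dim,
)
fi_fwd_out = wrapper.forward(q, k, v, causal=True, sm_scale=softmax_scale)
wrapper.end_forward()

# Compare against the flash-attn output from the snippet in the first comment.
print(torch.allclose(fi_fwd_out, fa2_fwd_out, rtol=1e-3, atol=1e-3))
```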

Amanda-Barbara commented 1 week ago

@yzh119 thanks very much, I will try it.