flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

Fix a bug related to causal mask #348

Closed rchardx closed 3 days ago

rchardx commented 3 days ago

For example, the case (qo_len=2, kv_len=3) should have the following causal mask, where a position is masked when kv_idx >= kv_len + q_idx - qo_len:

\begin{pmatrix}
0 & -\infty & -\infty\\ 
0 & 0 & -\infty
\end{pmatrix}

The causal mask currently produced by include/flashinfer/attention/prefill.cuh and src/cpu_reference.h masks a position only when kv_idx > kv_len + q_idx - qo_len, which gives:

\begin{pmatrix}
0 & 0 & -\infty\\ 
0 & 0 & 0
\end{pmatrix}
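
To make the difference between the two predicates easy to check, here is a minimal Python sketch that builds both masks for (qo_len=2, kv_len=3). It is not code from the repository: `causal_mask` and the lambda predicates are hypothetical helpers written for illustration, and the sketch assumes each quoted condition marks the positions set to -inf.

```python
import math


def causal_mask(qo_len, kv_len, is_masked):
    """Build a qo_len x kv_len additive mask (illustrative helper, not FlashInfer API).

    is_masked(q_idx, kv_idx) returns True for positions that receive -inf.
    """
    return [
        [-math.inf if is_masked(q_idx, kv_idx) else 0.0 for kv_idx in range(kv_len)]
        for q_idx in range(qo_len)
    ]


qo_len, kv_len = 2, 3

# Predicate proposed in this PR: mask when kv_idx >= kv_len + q_idx - qo_len.
proposed = causal_mask(qo_len, kv_len, lambda q, k: k >= kv_len + q - qo_len)

# Predicate quoted for prefill.cuh / cpu_reference.h: mask when kv_idx > kv_len + q_idx - qo_len.
current = causal_mask(qo_len, kv_len, lambda q, k: k > kv_len + q - qo_len)

for name, mask in (("proposed", proposed), ("current", current)):
    print(name)
    for row in mask:
        print("  ", row)
```

Under the `>` predicate the last query row attends to every key, including the final one; under the `>=` predicate each query row attends to one key fewer.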