Closed by rchardx 3 days ago
For example, with `qo_len=2` and `kv_len=3`, the causal mask should set an entry to $-\infty$ when `kv_idx >= kv_len + q_idx - qo_len` (rows are `q_idx`, columns are `kv_idx`):
$$\begin{pmatrix} 0 & -\infty & -\infty \\ 0 & 0 & -\infty \end{pmatrix}$$
The causal mask produced by `include/flashinfer/attention/prefill.cuh` and `src/cpu_reference.h` masks an entry when `kv_idx > kv_len + q_idx - qo_len` instead, which gives:
$$\begin{pmatrix} 0 & 0 & -\infty \\ 0 & 0 & 0 \end{pmatrix}$$
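
The two matrices can be reproduced with a minimal Python sketch. This is not the flashinfer code: `build_mask` and its `strict` flag are hypothetical names used only for illustration, and the sketch assumes each condition marks the entries that get set to $-\infty$.

```python
import math


def build_mask(qo_len, kv_len, strict):
    """Return a qo_len x kv_len additive mask (hypothetical helper).

    An entry is -inf when it is masked out and 0.0 otherwise.
    strict=False masks when kv_idx >= kv_len + q_idx - qo_len.
    strict=True  masks when kv_idx >  kv_len + q_idx - qo_len.
    """
    mask = []
    for q_idx in range(qo_len):
        threshold = kv_len + q_idx - qo_len
        row = []
        for kv_idx in range(kv_len):
            masked = kv_idx > threshold if strict else kv_idx >= threshold
            row.append(-math.inf if masked else 0.0)
        mask.append(row)
    return mask


# Mask described as expected in this issue (>= comparison).
expected_mask = build_mask(2, 3, strict=False)
# Mask attributed to prefill.cuh and cpu_reference.h (> comparison).
observed_mask = build_mask(2, 3, strict=True)

for name, mask in (("expected", expected_mask), ("observed", observed_mask)):
    print(name)
    for row in mask:
        print("  ", row)
```

The only difference between the two results is the strictness of the comparison, which moves the masked boundary by one column in every row.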