Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Bug when window_size_right > max_seq_k and seq_q > seq_k ? #895

Open helloLLM666 opened 5 months ago

helloLLM666 commented 5 months ago

I'm not sure why window_size_right is reset to -1 when it is larger than max_seq_k: https://github.com/Dao-AILab/flash-attention/blob/36587c01cb4390de0a590b2131e3fcc4859ba09c/csrc/flash_attn/flash_api.cpp#L1066

Later in the code, window_size_right is then set to seqlen_k: https://github.com/Dao-AILab/flash-attention/blob/36587c01cb4390de0a590b2131e3fcc4859ba09c/csrc/flash_attn/flash_api.cpp#L124

where seqlen_k is set to max_seqlen_k when the set_params_xxx API is called: https://github.com/Dao-AILab/flash-attention/blob/36587c01cb4390de0a590b2131e3fcc4859ba09c/csrc/flash_attn/flash_api.cpp#L1160

As a result, in cases where seq_q > seq_k, the right bound of the sliding window may mistakenly mask out some keys along the seq_k dimension:

  1. By default, the central line of the sliding window is bottom-right aligned: https://github.com/Dao-AILab/flash-attention/blob/36587c01cb4390de0a590b2131e3fcc4859ba09c/flash_attn/flash_attn_interface.py#L1032

  2. When seq_q > seq_k, we need a right window size of seq_k + (seq_q - seq_k) = seq_q to cover the whole seq_k dimension on the right.

  3. Assume we pass a window_size_right of value INT32_MAX, aiming to use all keys on the right-hand side. If window_size_right is reset to max_seqlen_k, then (seq_q - max_seqlen_k) keys would be dropped for the earliest queries.
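The steps above can be checked with a small sketch. This is not FlashAttention's kernel code; `local_mask` is a hypothetical helper that builds the boolean mask implied by a bottom-right-aligned sliding window (query `i` sits on key `i + seq_k - seq_q`, with `-1` meaning "unbounded" on that side, matching the convention described above):

```python
def local_mask(seq_q, seq_k, window_left, window_right):
    """mask[i][j] is True iff query i may attend key j.

    The diagonal is bottom-right aligned: query i is centered on key
    i + (seq_k - seq_q). A window size of -1 means unbounded on that side.
    """
    offset = seq_k - seq_q  # negative when seq_q > seq_k
    mask = []
    for i in range(seq_q):
        lo = 0 if window_left < 0 else i + offset - window_left
        hi = seq_k - 1 if window_right < 0 else i + offset + window_right
        mask.append([lo <= j <= hi for j in range(seq_k)])
    return mask

seq_q, seq_k = 6, 3

# Intended behavior: window_size_right >= seq_q covers every key for
# every query (here window_left = -1, i.e. unbounded to the left).
full = local_mask(seq_q, seq_k, -1, seq_q)

# Behavior after the clamp described above: window_size_right is reset
# to seq_k, so early queries lose their rightmost keys.
clamped = local_mask(seq_q, seq_k, -1, seq_k)

dropped = sum(a and not b
              for row_f, row_c in zip(full, clamped)
              for a, b in zip(row_f, row_c))
print(dropped)  # number of (query, key) pairs mistakenly masked out
```

With seq_q = 6 and seq_k = 3, query 0 can only reach key 0 and query 1 only keys 0-1 after the clamp, so three (query, key) pairs that should be attended are masked out.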

tridao commented 5 months ago

Yeah, I think you might be right, it's a bug.