Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Strided local attention #764

Open samvanstroud opened 8 months ago

samvanstroud commented 8 months ago

Thanks @timlacroix for adding local attention support. Do you know if it would be straightforward to add a "stride" to the local window attention, to better support the cross-attention case where there are significantly more keys than queries? In standard local attention this leads to some keys going unattended, as in:

[figure: standard local attention window leaving trailing keys unattended]

Instead, I'd like to be able to shift the window by multiple keys for each new query, as in the case below, where stride = 2:

[figure: local attention window shifted by two keys per query]
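For concreteness, here is a minimal NumPy sketch of the strided local mask being described. The function name, the `window_left`/`window_right`/`stride` parameters, and the window-centering convention are all hypothetical, not part of flash-attention's API:

```python
import numpy as np

def strided_local_mask(n_queries, n_keys, window_left, window_right, stride=1):
    """Boolean mask: True where query q may attend key k.

    The window for query q is centered at q * stride, so consecutive
    queries shift the window by `stride` keys instead of one.
    """
    q = np.arange(n_queries)[:, None]  # shape (n_queries, 1)
    k = np.arange(n_keys)[None, :]     # shape (1, n_keys)
    center = q * stride
    return (k >= center - window_left) & (k <= center + window_right)

# With stride=2, each query's window slides two keys to the right,
# so all 8 keys are covered by only 4 queries.
mask = strided_local_mask(n_queries=4, n_keys=8,
                          window_left=1, window_right=1, stride=2)
```

With `stride=1` and the same shapes, the last few keys fall outside every query's window, which is exactly the unattended-keys problem described above.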

tridao commented 8 months ago

It's not supported. You can change the logic of local masking (e.g. search for "Local" in the code). You'd also need to be careful when deciding which blocks will be empty if you change this stride.

samvanstroud commented 8 months ago

Thanks @tridao. Can you be any more specific about which changes would be needed? I guess as a start we'd need to add a multiplicative factor to the row_idx in the two lines here https://github.com/Dao-AILab/flash-attention/blob/5aca153d6d4c2f87840480b54330c923f6b36ec4/csrc/flash_attn/src/mask.h#L52-L53

Is the logic about empty blocks somewhere else?
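As a rough sketch of the multiplicative-factor idea, here is Python pseudocode loosely mirroring the row/column window-limit computation in `mask.h`. The function name, the `stride` parameter, and the exact offset handling are assumptions for illustration, not the actual CUDA kernel logic:

```python
def local_window_limits(row_idx, seqlen_k, seqlen_q,
                        window_left, window_right, stride=1):
    """Return the (leftmost, rightmost) key column a query row may attend.

    The existing local-attention logic effectively centers the window at
    row_idx + seqlen_k - seqlen_q; multiplying row_idx by `stride` (the
    factor suggested above) shifts the window by `stride` keys per query
    row instead of one. Offset handling here is a simplifying assumption.
    """
    center = row_idx * stride + seqlen_k - seqlen_q
    return center - window_left, center + window_right
```

Note that with `stride > 1` the window advances faster across key blocks, which is why the empty-block logic tridao mentions (the decision of which key blocks a query block can skip entirely) would also need to account for the stride.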