Open samvanstroud opened 8 months ago
tridao: It's not supported. You can change the logic of the local masking (e.g. search for "Local" in the code). You'd also need to be careful when deciding which blocks will be empty if you add such a stride.
samvanstroud: Thanks @tridao. Can you be any more specific about which changes would be needed? I guess as a start we'd need to add a multiplicative stride factor to row_idx in these two lines: https://github.com/Dao-AILab/flash-attention/blob/5aca153d6d4c2f87840480b54330c923f6b36ec4/csrc/flash_attn/src/mask.h#L52-L53. Is the logic that decides which blocks are empty somewhere else?
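To make the empty-block question concrete: in block-sparse terms, a (q_block, k_block) tile can be skipped when no query row in it reaches any key column. Here is a rough sketch of that condition under a hypothetical stride — the block size, function name, and symmetric window are all illustrative, not flash-attention's actual scheduling code:

```python
def block_nonempty(q_block, k_block, blk, window, stride):
    """True if any (query, key) pair in this (q_block, k_block) tile falls
    inside a strided local window: query i attends keys in
    [i*stride - window, i*stride + window].

    The closed-form interval test below assumes stride <= 2*window + 1,
    so that consecutive windows overlap and the keys reachable from the
    q-block form one contiguous interval."""
    q_lo = q_block * blk             # first query row in the tile
    q_hi = (q_block + 1) * blk - 1   # last query row in the tile
    k_lo = k_block * blk             # first key column in the tile
    k_hi = (k_block + 1) * blk - 1   # last key column in the tile
    # Keys reachable from this q-block: [q_lo*stride - window, q_hi*stride + window]
    return k_hi >= q_lo * stride - window and k_lo <= q_hi * stride + window
```

Note that with stride > 1 the reachable key interval grows stride times faster than the query index, so more k-blocks become non-empty for a given q-block than with plain local attention — that is presumably the scheduling logic that would need updating alongside the per-element mask.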
samvanstroud (original post): Thanks @timlacroix for adding local attention support. Do you know if it would be straightforward to add a "stride" to the local window attention, to better support the cross-attention case where there are significantly more keys than queries? With standard local attention, some of the keys are never attended to, as in the first attached image. Instead I'd like to be able to shift the window by multiple keys for each new query, as in the second image, where stride = 2.
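To make the request concrete, here is a minimal NumPy sketch of the mask being described; `strided_local_mask` and its `stride` parameter are hypothetical names, not part of flash-attention:

```python
import numpy as np

def strided_local_mask(seqlen_q, seqlen_k, window, stride=1):
    # Query i attends keys in [i*stride - window, i*stride + window];
    # with stride=1 this reduces to an ordinary symmetric local window.
    q = np.arange(seqlen_q)[:, None]  # query indices (column vector)
    k = np.arange(seqlen_k)[None, :]  # key indices (row vector)
    center = q * stride               # window centre shifts by `stride` per query
    return (k >= center - window) & (k <= center + window)

# 4 queries, 8 keys, window of 1 key on each side.
# With stride=1 the trailing keys are never attended; stride=2 covers them all.
plain = strided_local_mask(4, 8, window=1, stride=1)
strided = strided_local_mask(4, 8, window=1, stride=2)
print(plain.any(axis=0))    # which keys get attended with stride=1
print(strided.any(axis=0))  # which keys get attended with stride=2
```

With stride=1 only keys 0–4 are reachable from the 4 queries, while stride=2 spreads the same window budget across all 8 keys — exactly the cross-attention coverage gap described above.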