imoneoi opened this issue 1 year ago
It is already implemented here and is utilized by Mistral-7B (via xFormers)
Do we need to create a patch to use this?
FlashAttention supports sliding windows if you pass in the window_size
parameter. I've updated the README with this.
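To picture what the `window_size=(left, right)` parameter means (my reading of the semantics, so treat it as a sketch rather than the kernel's definition): query position i may attend key position j when i - left <= j <= i + right, with -1 meaning unlimited on that side. A tiny pure-Python illustration:

```python
def allowed(i, j, left, right):
    """True if query position i may attend key position j under
    FlashAttention-style window_size=(left, right); -1 means
    unlimited on that side.  Pure-Python sketch, not the kernel."""
    if left != -1 and j < i - left:
        return False
    if right != -1 and j > i + right:
        return False
    return True

# window_size=(2, 0): a causal band of width 3
mask = [[allowed(i, j, 2, 0) for j in range(5)] for i in range(5)]
```

For example, row 3 of `mask` allows keys 1..3 only: causal (nothing to the right) plus a sliding-window cut-off on the left.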
Just for clarification: the causal mask is applied on top of the sliding window, right? That is, with the standard decoding mask the query still cannot attend to tokens to its right. Also, do I need to handle the case where the sliding window is larger than the number of tokens before the current token, or is there an automatic cut-off?
@tridao I would like to clarify the options for sliding window attention in the context of Mistral 7B.
Correct me if my understanding is incorrect:
At training, causal=True. We need the proper triangular + band mask, which masks everything beyond i - sliding_window. So the call to flash_attn_func should use window_size=(4096, 0). Is this correct?
At inference, causal is False. If we pass k, v already trimmed to seqlen <= sliding_window by managing the KV cache up front, what should window_size be? (-1, -1)? (0, 0)?
If the KV cache contains the full sequence (not trimmed, i.e. not optimized), what should window_size be?
Thanks.
> @tridao I would like to clarify the options for sliding window attention in the context of Mistral 7B.

The source of truth is the reference attention implementation that we test against.

> Correct me if my understanding is incorrect:
> At training, causal=True. We need the proper triangular + band mask, which masks everything beyond i - sliding_window. So the call to flash_attn_func should use window_size=(4096, 0). Is this correct?

Yes, window_size = (4096, 0).
> At inference, causal is False. If we pass k, v already trimmed to seqlen <= sliding_window by managing the KV cache up front, what should window_size be? (-1, -1)? (0, 0)?

If seqlen_q = 1 and you have already trimmed the KV cache, then you don't need any masking, i.e. the (-1, -1) window size, which is the default. You can also pass (4096, 0) if you want, and it would yield the same answer.
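This can be sanity-checked on the mask alone: if the cache is already trimmed to at most 4096 tokens, a (4096, 0) window hides no key, so it behaves exactly like (-1, -1). A small pure-Python check (my reading of the window semantics, not the kernel itself; cache lengths are illustrative):

```python
def masked_positions(n_kv, left):
    """Key indices hidden from a single query sitting at the last
    position (index n_kv - 1) under a left window; -1 = unlimited.
    Pure-Python sketch of the window semantics, not the kernel."""
    i = n_kv - 1
    if left == -1:
        return set()
    return {j for j in range(n_kv) if j < i - left}

# trimmed cache of 4096 tokens: a (4096, 0) window hides nothing,
# exactly like the default (-1, -1)
assert masked_positions(4096, 4096) == masked_positions(4096, -1) == set()
```

With an untrimmed cache the window does matter: at 6000 cached tokens, the oldest 5999 - 4096 = 1903 keys fall outside the band and get masked.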
> If the KV cache contains the full sequence (not trimmed, i.e. not optimized), what should window_size be?

(4096, 0)
> Thanks.

Generally, passing the same window size (4096, 0) should work for all cases. But of course you should test.
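One way to run that test without a GPU is a toy pure-Python reference (not the FlashAttention kernel; the shapes and the window of 3 are illustrative): decoding one token against the full cache with a left window should give the same output as decoding against a cache trimmed up front with no window.

```python
import math

def ref_decode(q, ks, vs, left):
    """Single-query attention over a list of cached keys/values.
    left=-1 disables the window; otherwise keys with index < i - left
    are masked, where i is the query's position (the last slot).
    Pure-Python reference only, not the FlashAttention kernel."""
    i = len(ks) - 1
    scores = [
        float("-inf") if (left != -1 and j < i - left)
        else sum(a * b for a, b in zip(q, k))
        for j, k in enumerate(ks)
    ]
    m = max(scores)                      # masked entries exp to 0.0
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    out = [0.0] * len(vs[0])
    for p, v in zip(exps, vs):
        for d, x in enumerate(v):
            out[d] += (p / z) * x
    return out

# toy data: 6 cached tokens, sliding window of 3
ks = [[0.1 * t, 0.2] for t in range(6)]
vs = [[float(t), 1.0] for t in range(6)]
q = [1.0, -0.5]

full = ref_decode(q, ks, vs, left=2)           # window keeps the last 3 keys
trimmed = ref_decode(q, ks[-3:], vs[-3:], -1)  # cache trimmed up front, no window
assert all(abs(a - b) < 1e-12 for a, b in zip(full, trimmed))
```

The two paths agree because the masked keys receive exactly zero weight, matching the point above that the same window size works whether or not the cache is trimmed.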
OK, thanks. So I think this implementation is wrong: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L514-L531 and I'm happy to understand it correctly :)
Does FlashAttention consider implementing sliding window attention (like used in Mistral)?
https://huggingface.co/mistralai/Mistral-7B-v0.1