Dao-AILab / flash-attention

Fast and memory-efficient exact attention

Feature request: Sliding Window Attention #580

Open imoneoi opened 1 year ago

imoneoi commented 1 year ago

Would FlashAttention consider implementing sliding window attention (like the one used in Mistral)?

https://huggingface.co/mistralai/Mistral-7B-v0.1

david-macleod commented 1 year ago

It is already implemented here and is utilized by Mistral-7B (via xFormers)

MrigankRaman commented 1 year ago

Do we need to create a patch to use this?

tridao commented 1 year ago

FlashAttention supports sliding windows if you pass in the window_size parameter. I've updated the README with this.
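For concreteness, here is a minimal sketch of such a call, assuming flash-attn >= 2.3 (where flash_attn_func accepts a window_size=(left, right) argument) and the usual (batch, seqlen, nheads, headdim) tensor layout; the shapes and sizes below are just placeholders:

```python
# Minimal sketch, assuming flash-attn >= 2.3 and fp16 tensors on a CUDA device,
# laid out as (batch, seqlen, nheads, headdim).
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 8192, 32, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal attention restricted to a sliding window: each query attends only to
# keys at most 4096 positions to its left, and never to its right.
out = flash_attn_func(q, k, v, causal=True, window_size=(4096, 0))
```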

artnoage commented 1 year ago

Just for clarification, the mask is applied on top of the sliding window, right? That is, a token doesn't see to its right if I use the standard decoding mask. Also, do I need to do something when the sliding window is larger than the number of tokens before the current token, or is there an automatic cut-off?

vince62s commented 1 year ago

@tridao I would like to clarify the options for sliding window attention in the context of Mistral 7B.

Correct me if my understanding is incorrect:

At training, causal=True. We need the proper triangular + band mask, which masks everything beyond i - sliding_window. So the call to flash_attn_func should use window_size=(4096, 0). Is this correct?

At inference, causal is False. In the case where we pass k, v ALREADY trimmed to seqlen <= sliding_window by managing the KV cache up front, what should the values for window_size be? (-1, -1)? (0, 0)?

If the KV cache contains the full sequence (not trimmed, i.e. not optimized), what should window_size be?

Thanks.

tridao commented 1 year ago

> @tridao I would like to clarify the options for sliding window attention in the context of Mistral 7B.

The source of truth is the reference attention implementation that we test against.

> Correct me if my understanding is incorrect:

> At training, causal=True. We need the proper triangular + band mask, which masks everything beyond i - sliding_window. So the call to flash_attn_func should use window_size=(4096, 0). Is this correct?

Yes, window_size = (4096, 0).

> At inference, causal is False. In the case where we pass k, v ALREADY trimmed to seqlen <= sliding_window by managing the KV cache up front, what should the values for window_size be? (-1, -1)? (0, 0)?

If seqlen_q = 1 and you have already trimmed the KV cache, then you don't need any masking (i.e. the (-1, -1) window size, which is the default). You can also pass (4096, 0) if you want, and it would yield the same answer.

> If the KV cache contains the full sequence (not trimmed, i.e. not optimized), what should window_size be?

(4096, 0)

> Thanks.

Generally, passing the same window size (4096, 0) should work for all cases. But of course you should test.
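To make those semantics concrete, here is an illustrative sketch (not the library's own reference code) of the boolean mask that window_size=(window, 0) corresponds to when seqlen_q == seqlen_k: query i may attend to keys j with roughly i - window <= j <= i.

```python
# Illustrative sketch only: the causal + band mask implied by
# window_size=(window, 0) when seqlen_q == seqlen_k.
import torch

def sliding_window_causal_mask(seqlen: int, window: int) -> torch.Tensor:
    i = torch.arange(seqlen).unsqueeze(1)  # query positions, shape (seqlen, 1)
    j = torch.arange(seqlen).unsqueeze(0)  # key positions, shape (1, seqlen)
    causal = j <= i          # triangular part: no looking to the right
    band = j >= i - window   # band part: at most `window` keys to the left
    return causal & band     # True where attention is allowed

# Decoding case: with seqlen_q == 1 and a KV cache already trimmed to the last
# `window` tokens, every cached key falls inside the band, which is why
# (-1, -1) (no masking) and (4096, 0) give the same result, as noted above.
mask = sliding_window_causal_mask(seqlen=8, window=3)
```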

vince62s commented 1 year ago

OK, thanks. So I think this implementation is wrong: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L514-L531. I'm glad I understand it correctly now :)