tomaarsen / attention_sinks

Extend existing LLMs way beyond the original training length with constant memory usage, without retraining
https://huggingface.co/blog/tomaarsen/attention-sinks
Apache License 2.0

Issue with only adding sink tokens in cache #17

Open sam1373 opened 9 months ago

sam1373 commented 9 months ago

It seems that in this implementation you are only adding the "sink" tokens to the cache, and not using them in the original forward pass, so if you are using windowed attention and your input token sequence is longer than the window size, then the later tokens in this pass will not attend to any "sink" tokens. In your tests I think this issue would mostly be avoided because the original prompt is either short before you start generating tokens, or, in the case of your streaming test, you end up doing a new forward pass separately for each added sub-prompt, which are relatively short. However, it's unclear from this how well it would work if you wanted to do a single forward pass on a long text. Do you have any tests showing that this still works with this implementation?

tomaarsen commented 9 months ago

Hello!

I'm having a bit of a hard time trying to understand what you're saying. Let me try to break it down a bit.

> It seems that in this implementation you are only adding the "sink" tokens to the cache, and not using them in the original forward pass, so if you are using windowed attention and your input token sequence is longer than the window size, then the later tokens in this pass will not attend to any "sink" tokens.

When the input token sequence is larger than the window size, the later tokens do still attend to the sink tokens. To give a toy example, consider a scenario with a window size of 10, including 4 attention sink tokens, and a text that is just a space-separated alphabet. When generating, the model sees:

A
A B
A B C
A B C D
A B C D E
A B C D E F
A B C D E F G
A B C D E F G H 
A B C D E F G H I
A B C D E F G H I J
A B C D F G H I J K
A B C D G H I J K L
A B C D H I J K L M
...

So the later tokens, e.g. Z, can still attend to the sink tokens A, B, C, D.
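In other words, the cache policy is "keep the first few sink tokens plus the most recent tokens, and discard the middle". Here is a minimal sketch of that bookkeeping (operating on plain strings for illustration rather than the real key/value cache, which works the same way conceptually):

```python
# Minimal sketch of the "keep sinks + recent window" eviction policy illustrated above.
# `window_size` counts the sink tokens, matching the toy example (10 total, 4 sinks).
def evict(cache, window_size=10, num_sink_tokens=4):
    if len(cache) <= window_size:
        return cache
    sinks = cache[:num_sink_tokens]
    recent = cache[-(window_size - num_sink_tokens):]
    return sinks + recent

tokens = [chr(ord("A") + i) for i in range(11)]  # A .. K
print(evict(tokens))  # ['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K'] -> E is dropped
```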

> In your tests I think this issue would mostly be avoided because the original prompt is either short before you start generating tokens, or, in the case of your streaming test, you end up doing a new forward pass separately for each added sub-prompt, which are relatively short.

Did you perhaps mean that if you have an input that is longer than the window size, then the first non-sink tokens will be discarded? I.e., the question becomes "What if the prompt is 2k tokens and the window size is 1024?" In that case, the model won't do well. It will produce normal English text like you would expect, but it won't be able to answer correctly if the prompt starts with a question followed by 1900 tokens of context, because the question itself will have been evicted from the window.

> However, it's unclear from this how well it would work if you wanted to do a single forward pass on a long text. Do you have any tests showing that this still works with this implementation?

I don't, but I think that this would not work well, assuming that I understand you correctly. Any model loaded with attention_sinks will still obey the original "context length", i.e. the number of tokens it can reasonably look back over without "forgetting". See for example questions 2 and 3 in the FAQ:

> 2. Is the context window of LLMs expanded? No. The context window remains unchanged. Only the most recent tokens and attention sinks are retained, discarding middle tokens. This means the model can only process the latest tokens. The context window remains constrained by its initial pre-training. For instance, if Llama-2 is pre-trained with a context window of 4096 tokens, then the maximum cache size for an Attention Sink model on Llama-2 remains 4096.
>
> 3. Can I input an extensive text, like a book, into an Attention Sink model for summarization? While you can input a lengthy text, the model will only recognize the latest tokens. Thus, if a book is an input, an Attention Sink model might only summarize the concluding paragraphs, which might not be very insightful. As emphasized earlier, we neither expand the LLMs' context window nor enhance their long-term memory. An Attention Sink model's strength lies in generating fluent text from recent tokens without needing a cache refresh.
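As a rough sketch of what that bound looks like in practice (keyword names follow the loading example in this repo's README; the exact values are illustrative, not a recommendation), the sink tokens and the recent-token window together make up the fixed-size cache:

```python
from attention_sinks import AutoModelForCausalLM

# The KV cache is capped at attention_sink_size + attention_sink_window_size tokens,
# which should stay within the model's pre-trained context (4096 for Llama-2).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    attention_sink_size=4,            # the retained "sink" tokens
    attention_sink_window_size=4092,  # the most recent tokens that are kept
)
```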

Please let me know if I understood you correctly, and if I answered your question!

sam1373 commented 9 months ago

Hi Tom, I don't think you understood the issue I am describing. I did mean the sink tokens not being attended to. Consider the following situation: we are using the Mistral model with windowed flash attention, and we have a starting text of, say, 30k tokens. We do the first forward pass on these tokens, before any cache is created (as there are no previous tokens). We want to be able to do this efficiently in one forward pass, as this is part of the draw of using windowed attention in the first place. However, if you're only adding sink tokens in the cache, the attention done within this pass will just use the local windows. You will only start using sink tokens when generating tokens after that; however, you will then be attending to the cached KV of the final 4k tokens of those 30k, all of which will be "poisoned" by having been created without sink attention.
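To spell out the mask arithmetic (a toy illustration only, not Mistral's or flash attention's actual code): under plain causal sliding-window attention, a query late in that 30k prefill can only see keys inside the window behind it, so positions 0-3 (the would-be sinks) are never attended during that pass:

```python
# Toy illustration of causal sliding-window attention during a single long prefill:
# which key positions can a given query position attend to?
def allowed_keys(query_pos, window=4096):
    return range(max(0, query_pos - window + 1), query_pos + 1)

print(0 in allowed_keys(20_000))       # False: position 0 (a would-be sink) is out of reach
print(19_000 in allowed_keys(20_000))  # True: recent positions are still visible
```

Generation afterwards then reads from a KV cache that was built entirely under this mask, which is what I mean by "poisoned".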

There is a separate question of how much doing this large forward pass even matters. Technically, information could certainly propagate through the KV cache from beyond the local window, but whether this actually happens with Mistral or other such models, and to what extent, is more complicated and would depend on how they were trained, etc. If it doesn't happen at all, then you could just as well take the 4k suffix of the text and only do the initial forward pass on that, and in that case there would also be no point to the windowed flash attention implementation.

tomaarsen commented 9 months ago

I see! Indeed, I haven't tested this approach when the original input already exceeds the window size. That would definitely be an interesting experiment - so far all of my attempts at this give OOM exceptions.

Nintorac commented 9 months ago

Oh, missed this issue, have made some comments on the subject in #14

> but whether this actually happens with Mistral or other such models, and to what extent, is more complicated and would depend on how they were trained, etc.

I tested this with a 16-token context window on Mosaic. I observed it maintaining coherence beyond the context window, but it does go off track quite quickly.