huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Implement StreamingLLM/Windowed Attention with Attention Sinks #26553

Closed tomaarsen closed 1 year ago

tomaarsen commented 1 year ago

Feature request

Hello!

I would love to see StreamingLLM/Windowed Attention with Attention Sinks implemented, as proposed in https://arxiv.org/abs/2309.17453. The primary author (@Guangxuan-Xiao) has also released the code here: https://github.com/mit-han-lab/streaming-llm. I've adapted that code into a drop-in replacement for transformers so that people can use it: https://github.com/tomaarsen/attention_sinks. For example:

from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")


[Figure: schemes]

The paper shows that adapting windowed attention such that the first 4 tokens of the input sequence are always in the window allows every tested LLM (Llama 2, MPT, Falcon, Pythia) to scale to endless inputs without catastrophic perplexity increases, all without any form of retraining. In other words, scaling any pretrained LLM to infinite sequence length is as simple as:

  1. Converting the attention to windowed attention.
  2. Using a special cache for the windowed attention that always keeps the first 4 (by default) tokens in the cache.
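As a minimal sketch of the cache policy in step 2, here are the original-text positions such a cache would retain after `seq_len` tokens: the first few tokens plus the most recent window. It is plain Python, and `sink_cache_positions`, `num_sink_tokens` and `window_size` are hypothetical names, not part of attention_sinks or transformers. A real implementation would slice key/value tensors along the sequence dimension instead.

```python
def sink_cache_positions(seq_len: int, num_sink_tokens: int = 4, window_size: int = 1020) -> list[int]:
    """Return the original-text positions kept by a hypothetical attention sink cache.

    The first `num_sink_tokens` tokens are always kept; of the remaining tokens,
    only the `window_size` most recent ones are kept.
    """
    if seq_len <= num_sink_tokens + window_size:
        # Nothing has been evicted yet.
        return list(range(seq_len))
    sinks = list(range(num_sink_tokens))
    recent = list(range(seq_len - window_size, seq_len))
    return sinks + recent


# With 4 sink tokens and a window of 6, the cache never holds more than 10 entries.
print(sink_cache_positions(13, num_sink_tokens=4, window_size=6))
# [0, 1, 2, 3, 7, 8, 9, 10, 11, 12]
```

Because the cache size is capped at `num_sink_tokens + window_size`, memory stays constant no matter how long the input grows.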

Using this elementary approach, the authors were able to keep various LLMs stable when feeding them (!) 4 million tokens.

Motivation

Maximum sequence lengths have been an important topic for a while now, with solutions ranging from RoPE to LongLoRA to YaRN, but each of these has its limits, and some also require retraining or additional training. Windowed attention with attention sinks seems to address this problem without any retraining, and it would be an extremely valuable addition.

I can vouch for the results in the paper. I've gotten these results for Llama 2 7B using my own implementation: [Figure: llama_2_7b_ppl_vram]

Your contribution

Yes. I would love to help implement this into core transformers rather than in my drop-in implementation. However, I would like to discuss:

  1. Whether this feature is a good fit for transformers.
  2. Where we store the code for converting each model (e.g. Llama, Pythia, Falcon) to windowed attention. See e.g. this file for an example.
  3. Where we store the code for applying the Attention Sink KV Cache after a forward call. See e.g. this file for an example.

The primary author of the paper has also expressed interest in a transformers implementation here.

LysandreJik commented 1 year ago

Hey @tomaarsen, very cool feature and implementation!

This definitely looks like a good fit for transformers, or at least it should be of very high value for the community to have access to attention sinks very easily.

Keeping a drop-in implementation up to date in the long term is hard to do, so I would recommend we move towards a utility function for now that could eventually be upstreamed into transformers once it has matured a bit more.

So instead of the current

from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")

how about something like

from attention_sinks import convert_model_attention_to_sinks
from transformers import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
model = convert_model_attention_to_sinks(model)

?


Eventually, the API could take two different directions:

  1. Either we develop it similarly to the existing BetterTransformer support: it depends on the optimum library being installed in the environment, and offers the method model.to_bettertransformer() to convert the model to the right format.
  2. Or we add support for it directly in the from_pretrained method, like we do for Flash Attention: AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", use_flash_attention_2=True)

The first path is likely the most scalable; we would work a bit on the model definition to enable "plugins" from third-party libraries, enabling support for many third-party tools. The second one would offer support in core transformers directly, but for this we would really want validation across many models first.

cc @ArthurZucker @younesbelkada @patrickvonplaten @ydshieh

tomaarsen commented 1 year ago

Hello!

I'm glad that you're open to the idea of adding this to transformers! I think it would be of enormous value. First of all, I agree that the implementation of

from attention_sinks import AutoModel

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")

is not workable long-term, but I think a solution similar to BetterTransformer in optimum is viable. People can approach the third-party library (e.g. optimum in the case of BetterTransformer, or attention_sinks) and propose a conversion method that adds Attention Sinks to whatever architecture isn't supported yet. The goal/scope of the third party would then essentially be to act as a dictionary mapping architectures to conversion functions (rather than also providing AutoModel, AutoModelForCausalLM, LlamaModel, etc.).

In transformers we would have a conversion method (e.g. add_attention_sinks), which applies the conversion from the third party, if it exists. This might be preferable from an API perspective to your option 2, as this method can be given args and kwargs, such as the attention_sink_size (e.g. the first 4 tokens) and window_size (e.g. 1020 tokens). Adding more args and kwargs is more scalable this way, as we can't just willy-nilly add these kwargs to transformers' AutoModel.from_pretrained. This is important to consider, as the research on this is extremely new, so we might require more arguments in the future as the research expands.
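The division of labour described above could look roughly like the following sketch. Everything in it is hypothetical: `SINK_CONVERTERS`, `register_sink_converter`, `add_attention_sinks` and the stand-in `convert_llama` are illustrative names, not an existing transformers or attention_sinks API. The third party owns the architecture-to-converter mapping, while the entry point only looks up `model.config.model_type` and forwards keyword arguments such as `attention_sink_size` and `window_size`.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict

# Hypothetical third-party registry: model_type -> conversion function.
SINK_CONVERTERS: Dict[str, Callable[..., Any]] = {}


def register_sink_converter(model_type: str):
    """Decorator that registers a conversion function for one architecture."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        SINK_CONVERTERS[model_type] = fn
        return fn
    return decorator


@register_sink_converter("llama")
def convert_llama(model, attention_sink_size: int = 4, window_size: int = 1020):
    # Stand-in only: a real converter would patch the attention forward (position
    # shift) and wrap the model forward with a sink cache update.
    model.attention_sink_size = attention_sink_size
    model.window_size = window_size
    return model


def add_attention_sinks(model, **kwargs):
    """Hypothetical entry point: apply the third-party conversion if one exists."""
    model_type = model.config.model_type
    if model_type not in SINK_CONVERTERS:
        raise ValueError(f"Attention sinks are not supported for model type {model_type!r}")
    return SINK_CONVERTERS[model_type](model, **kwargs)


# Usage with a dummy object standing in for a pretrained model.
dummy = SimpleNamespace(config=SimpleNamespace(model_type="llama"))
converted = add_attention_sinks(dummy, attention_sink_size=4, window_size=1020)
print(converted.window_size)  # 1020
```

With this split, adding a new keyword argument only touches the converters, not `from_pretrained`.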

I'm curious to hear your thoughts on this.

For reference, today I will be adding support for MPT, Falcon, Pythia alongside my existing Llama support to attention_sinks.

patrickvonplaten commented 1 year ago

[Brainstorming] I'm wondering whether we could use this issue as a catalyst to improve our cache / past key value design we have in Transformers as it needs to be updated anyways soon (cc @gante as well).

@tomaarsen do you think we could support StreamingLLM for every model just by defining a "StreamingLLM/AttentionSink" cache that can be passed to the forward method (as past_key_values) and that would then take care of correctly creating the past key values at each step?

Here a GitHub gist of what I'm thinking of: https://gist.github.com/patrickvonplaten/7411f84b8a2cca3bc8e63df315d7d618

In short, this would entail some more fundamental changes to Transformers (essentially, every attention layer would call cache.update(...) if past_key_values is an object of type Cache), but I think this is something we want to do anyway to allow torch.compile to work better. We would then also give generate a new function argument, generate(..., cache=cache), that can optionally be passed.
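To make the "every attention layer calls cache.update(...)" idea concrete, here is a minimal sketch of a cache object with a sink variant that evicts the middle of the sequence. It is plain Python over lists (a real implementation would concatenate and slice key/value tensors along the sequence dimension). The class names, the `update(layer_idx, new_keys, new_values)` signature and the eviction rule are assumptions for illustration, not the contents of the gist linked above.

```python
class KVCache:
    """Hypothetical per-layer key/value cache that every attention layer updates."""

    def __init__(self):
        self.keys = {}
        self.values = {}

    def update(self, layer_idx, new_keys, new_values):
        """Append this step's keys/values for one layer and return all entries to attend to."""
        self.keys.setdefault(layer_idx, []).extend(new_keys)
        self.values.setdefault(layer_idx, []).extend(new_values)
        return self.keys[layer_idx], self.values[layer_idx]


class AttentionSinkKVCache(KVCache):
    """Keeps the first `num_sink_tokens` entries plus the `window_size` most recent ones."""

    def __init__(self, num_sink_tokens=4, window_size=1020):
        super().__init__()
        self.num_sink_tokens = num_sink_tokens
        self.window_size = window_size

    def update(self, layer_idx, new_keys, new_values):
        keys, values = super().update(layer_idx, new_keys, new_values)
        if len(keys) > self.num_sink_tokens + self.window_size:
            # Evict the oldest non-sink entries.
            start = len(keys) - self.window_size
            self.keys[layer_idx] = keys[: self.num_sink_tokens] + keys[start:]
            self.values[layer_idx] = values[: self.num_sink_tokens] + values[start:]
        return self.keys[layer_idx], self.values[layer_idx]


# Toy usage: one layer, one "token" per step.
cache = AttentionSinkKVCache(num_sink_tokens=4, window_size=6)
for token in "ABCDEFGHIJKLM":
    keys, values = cache.update(0, [token], [token])
print(" ".join(keys))  # A B C D H I J K L M
```

The attention layer itself stays unaware of the eviction policy; swapping the cache class is enough to change it.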

Would be curious to hear what you think about this idea! At this stage it is definitely still pure brainstorming, but I think this could be a cool long-term solution that would also be quite easy to implement.

tomaarsen commented 1 year ago

@patrickvonplaten

I'm afraid that your proposal would not quite be enough to implement the Attention Sink approach in all models. In addition to the cache, the approach requires that the position IDs are shifted within the window. To give a toy example with 4 attention sink tokens, a window size of 6, and a text that is just the space-separated alphabet, the model sees:

A
A B
A B C
A B C D
A B C D E
A B C D E F
A B C D E F G
A B C D E F G H 
A B C D E F G H I
A B C D E F G H I J
A B C D F G H I J K
A B C D G H I J K L
A B C D H I J K L M
...

With these position IDs:

0
0 1
0 1 2
0 1 2 3
0 1 2 3 4
0 1 2 3 4 5
0 1 2 3 4 5 6
0 1 2 3 4 5 6 7
0 1 2 3 4 5 6 7 8
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
...

i.e. the position IDs follow the tokens' slots within the cache (0 through 9 here) rather than their positions in the original text, so they stop growing once the window starts moving.

Or from the paper itself (Section 3.2, page 5):

When determining the relative distance and adding positional information to tokens, StreamingLLM focuses on positions within the cache rather than those in the original text. This distinction is crucial for StreamingLLM’s performance. For instance, if the current cache has tokens [0, 1, 2, 3, 6, 7, 8] and is in the process of decoding the 9th token, the positions assigned are [0, 1, 2, 3, 4, 5, 6, 7], rather than the positions in the original text, which would be [0, 1, 2, 3, 6, 7, 8, 9].
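The toy example and the paper's example can both be reproduced with a few lines of plain Python. This is an illustrative sketch (the function `sink_window_rows` is hypothetical): at each step it yields the tokens held in a cache with `num_sink_tokens` sinks and a window of `window_size`, together with the position IDs the model would use, which are simply the indices within the cache.

```python
import string


def sink_window_rows(tokens, num_sink_tokens=4, window_size=6):
    """Yield (cached_tokens, position_ids) after each new token is appended."""
    cache = []
    for token in tokens:
        cache.append(token)
        if len(cache) > num_sink_tokens + window_size:
            # Evict the oldest token that is not an attention sink.
            del cache[num_sink_tokens]
        # Positions are assigned within the cache, not in the original text.
        yield list(cache), list(range(len(cache)))


# The alphabet example: 4 sink tokens and a window of 6.
for cached, positions in sink_window_rows(string.ascii_uppercase[:13]):
    print(" ".join(cached), "|", " ".join(map(str, positions)))

# The paper's example: cache [0, 1, 2, 3, 6, 7, 8] while decoding token 9.
*_, (cached, positions) = sink_window_rows(range(10), num_sink_tokens=4, window_size=4)
print(cached, positions)  # [0, 1, 2, 3, 6, 7, 8, 9] [0, 1, 2, 3, 4, 5, 6, 7]
```

The last row of the alphabet example prints `A B C D H I J K L M | 0 1 2 3 4 5 6 7 8 9`, matching the toy table.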


In practice, this is fairly simple. For Mistral, it requires changing this rotary position embedding application here: https://github.com/huggingface/transformers/blob/2c7b26f5083becb429bdae4c919feca28fdf5699/src/transformers/models/mistral/modeling_mistral.py#L273 into one that only updates the query_states, e.g.

query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)

with

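import torch

def rotate_half(x):
    # Same helper as in modeling_mistral.py: rotates half the hidden dims of the input.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed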

Then, we update the key and value states using the cache, followed by an update to the cache. Only after that's done do we rotate the key_states with "faked" position IDs:

key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
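Why the keys must be cached before rotation can be seen with a toy 2-dimensional rotary embedding in plain Python. This is a simplified sketch, not the Mistral code: `rotate` stands in for `apply_rotary_pos_emb_single` with a single hypothetical frequency `theta`. Because keys are stored un-rotated, they can be re-rotated with in-cache positions at every step, and the query-key score then depends only on the in-cache distance, even after tokens have been evicted.

```python
import math


def rotate(vec, position, theta=0.5):
    """Toy rotary embedding: rotate a 2-d vector by position * theta radians."""
    x, y = vec
    angle = position * theta
    return (
        x * math.cos(angle) - y * math.sin(angle),
        x * math.sin(angle) + y * math.cos(angle),
    )


def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]


def attention_scores(query, cached_keys):
    """Score the newest query against un-rotated cached keys using in-cache positions."""
    kv_seq_len = len(cached_keys)
    # Keys get "faked" positions 0..kv_seq_len-1, like torch.arange(kv_seq_len).
    rotated_keys = [rotate(key, pos) for pos, key in enumerate(cached_keys)]
    # The query belongs to the newest cached token, so it takes the last in-cache position.
    rotated_query = rotate(query, kv_seq_len - 1)
    return [dot(rotated_query, key) for key in rotated_keys]


# A key two cache slots before the query gets the same score, whatever came before it.
query, key = (1.0, 0.0), (0.6, 0.8)
short = attention_scores(query, [key, (0.0, 1.0), (1.0, 1.0)])
longer = attention_scores(query, [(2.0, 0.0), key, (0.0, 1.0), (1.0, 1.0)])
print(math.isclose(short[0], longer[1]))  # True
```

If keys were instead cached after rotation with their original-text positions, the distance between the query and the sink tokens would keep growing with the text length.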

I took these snippets from my attention_sinks repository, here and here. I'd recommend checking out these sources, as these snippets might be confusing without their context.


The TL;DR is essentially that we need 2 changes to implement Attention Sinks correctly:

  1. Update the model architecture to shift the position IDs.
  2. Update the Attention Sink KV Cache using the past_key_values from every ...Model.forward call.

Your proposal would be a very elegant solution for the second part of the implementation, but not yet the former. I do the former in my pos_shift.py files for Mistral, Falcon, GPT-NeoX and Llama.

Sidenote: I added support for Mistral, GPT-NeoX, Falcon and MPT to attention_sinks 🎉 If the model perplexities are anything to go by, then it works great for everything that I've tried:

[Figure: Perplexity & VRAM plots for Llama 2 7B, Falcon 7B, MPT 7B, Pythia 6.9B and Mistral 7B]

patrickvonplaten commented 1 year ago

Great point about the position_ids, I indeed didn't think about this enough.

Also super nice to see that the other LLMs also work great with StreamingLLM's approach! Very encouraging!

Taking a step back here, I guess there are different levels of support we could offer for StreamingLLM:



We can achieve this by following the design described here for the cache, i.e.:

  1. Update the Attention Sink KV Cache using the past_key_values from every ...Model.forward call.

Now for:

  1. Update the model architecture to shift the position IDs.

it's indeed trickier!

One approach to handle this here could be to add an optional key_position_ids function argument here: https://github.com/huggingface/transformers/blob/2f3ea08a077ba3133fa8a604b22436cad250b055/src/transformers/models/falcon/modeling_falcon.py#L429

This would then propagate all the way to apply_rotary_pos_emb (https://github.com/huggingface/transformers/blob/2f3ea08a077ba3133fa8a604b22436cad250b055/src/transformers/models/llama/modeling_llama.py#L207), where key_position_ids would default to position_ids if not specified. This way, at every forward call in generate, the user could pass the correct (but different) position_ids for the query and the key, respectively. For the user this could then look as follows:

cache = SinkCache(window_length=256, num_sink_tokens=3)

query_pos_ids = ...
key_pos_ids = ....
model(input_ids, position_ids=query_pos_ids, key_position_ids=key_pos_ids, cache=cache)

cache = SinkCache(window_length=256, num_sink_tokens=3)

which could then also be passed to generate:

model.generate("....", cache=cache)

and regarding the position_ids it would also require only a small change in https://github.com/huggingface/transformers/blob/2f3ea08a077ba3133fa8a604b22436cad250b055/src/transformers/models/llama/modeling_llama.py#L1093 to correct the position ids for generation.

Questions:

Changing all the position_ids logic just for StreamingLLM might be a tough sell given that the method is still relatively new, but if it can be nicely done I think it'd be ok to extend the forward and prepare_inputs_for_generation method.

What do you think here @tomaarsen ?

In that regard some questions:

tomaarsen commented 1 year ago

Apologies for the delay, it was a busy day at work today. I'll go over each of your options:

1.) No real native support in Transformers

Although I'm definitely open to maintaining a third party package, it is not feasible for transformers as it stands right now. For each architecture I have to:

  1. ✅ Wrap the forward of the ...Model with a cache update, which I can implement fairly elegantly.
  2. ❌ Completely replace the entire forward method of all ...Attention classes to update the position IDs.

This requires me to completely pin each attention_sinks version to a very specific transformers version, which is not really viable, as much as I think it would be fun to maintain a plugin of transformers.


2.) Native support in Transformers but only for a "model.forward-level".

This is my personal preference intuitively.

Beyond that, generation doesn't work out of the box even with the forward methods correctly updated. I've encountered two noteworthy problems:

  1. The attention_mask in _update_model_kwargs_for_generation grows with a "1" for every token generated. Once the Sink KV Cache starts evicting tokens, this causes a shape mismatch. Easy fix here: https://github.com/tomaarsen/attention_sinks/blob/f46e63101fa74c6095e986c33284217c34a9fd88/attention_sinks/generation/utils.py#L38-L41

  2. The model.generate method does not return the past_key_values, preventing any form of multi-step generation (which is the primary use case of the Attention Sink approach: being able to keep prompting your model over and over and over without it losing fluency). If we update the cache like discussed prior, then this problem could be resolved by the user passing a Cache instance to model.generate which holds the updated past_key_values. This cache instance can then be reused for future model.generate calls.
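A minimal sketch of the idea behind the first fix, in plain Python rather than the actual `_update_model_kwargs_for_generation` code: when extending the mask for the next token, cap it at the number of cached entries plus the new token. The helper `update_attention_mask` and the list-of-ints mask are assumptions for illustration (a single, unpadded sequence); the linked attention_sinks change may differ in its details.

```python
def update_attention_mask(attention_mask: list[int], cache_length: int) -> list[int]:
    """Hypothetical helper: extend the mask for the next token, capped at cache_length + 1.

    Assumes a single, unpadded sequence, so every entry is 1. `cache_length` is the
    number of key/value entries left in the sink cache after eviction.
    """
    # The next forward pass attends to the cached entries plus the one new token;
    # entries for evicted tokens are dropped so the shapes keep matching.
    return (attention_mask + [1])[-(cache_length + 1):]


num_sink_tokens, window_size = 4, 6
mask = [1, 1, 1]          # attention mask of a 3-token prompt
cache_length = len(mask)  # cache size after the prefill forward pass
for _ in range(20):
    mask = update_attention_mask(mask, cache_length)
    assert len(mask) == cache_length + 1  # matches the key/value length of this step
    # The sink cache evicts after the forward call, so it never exceeds its limit.
    cache_length = min(cache_length + 1, num_sink_tokens + window_size)
print(len(mask))  # 11, whereas an ever-growing mask would have reached 23
```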


I think that the key_position_ids idea should work. An alternative is to implement the rotating and caching in a single method, so that only this method needs to be overridden by a third party (i.e. "attention_sinks") to provide this functionality.

Edit: Another alternative is a parameter on the cache class for cache_before_rotate which determines whether to cache before (like in Attention Sink) or after (normal) rotating.


As for your questions, I also invite @Guangxuan-Xiao to answer here, but I'll do my best to answer:

And some more info on the implementation:

For encoding like RoPE, we cache the Keys of tokens prior to introducing the rotary transformation. Then, we apply position transformation to the keys in the rolling cache at each decoding phase. On the other hand, integrating with ALiBi is more direct. Here, the contiguous linear bias is applied instead of a ’jumping’ bias to the attention scores. This method of assigning positional embedding within the cache is crucial to StreamingLLM’s functionality, ensuring that the model operates efficiently even beyond its pre-training attention window size.
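For ALiBi, the quoted passage amounts to computing the linear bias from distances within the cache. A minimal single-head sketch (the function name and the slope of 0.5 are assumptions for illustration; real ALiBi uses a different slope per head): the bias for query slot i and key slot j is slope * (j - i), where i and j are cache indices, so the bias never "jumps" where tokens were evicted.

```python
def alibi_bias_in_cache(cache_length: int, slope: float = 0.5) -> list[list[float]]:
    """Causal single-head ALiBi bias over cache slots: slope * (j - i) for j <= i, -inf otherwise."""
    return [
        [slope * (j - i) if j <= i else float("-inf") for j in range(cache_length)]
        for i in range(cache_length)
    ]


slope = 0.5
# Original-text positions held by a sink cache with 4 sink tokens and a window of 3.
original = [0, 1, 2, 3, 7, 8, 9]

# Contiguous bias for the newest token: distances are measured within the cache.
print(alibi_bias_in_cache(len(original), slope)[-1])
# [-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0]

# "Jumping" bias from original positions, for comparison.
print([slope * (p - original[-1]) for p in original])
# [-4.5, -4.0, -3.5, -3.0, -1.0, -0.5, 0.0]
```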

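For GPT2, I'd have to have a quick look at the implementation. I see it's a bit different from the modern LLMs, e.g. with GPT2LMHeadModel instead of GPT2ForCausalLM. However, I think the approach needs rotary embeddings (or another relative position encoding such as ALiBi), whereas GPT2 uses learned absolute position embeddings.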


Quasi-related: I've been pointed to a paper that does something very similar: https://arxiv.org/abs/2308.16137

It involves only a Λ-shaped attention mask (to avoid excessive attended tokens) and a distance limit (to avoid unseen distances) while requiring no parameter updates or learning. We find it applicable to a variety of LLMs using relative-position encoding methods. LM-Infinite is computationally efficient with O(n) time and space, and demonstrates consistent text generation fluency and quality to as long as 128k tokens on ArXiv and OpenWebText2 datasets, with 2.72x decoding speedup. We will make the codes publicly available following publication.

This "Λ-shaped attention mask" is kind of like always attending to the first tokens (i.e. the sink tokens) and "a distance limit" sounds like a window size.
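For comparison, a Λ-shaped mask of this kind can be built for a whole sequence at once, which is what allows encoding a sequence without a token-by-token loop. A minimal sketch (the function name and parameters are hypothetical, and this is not LM-Infinite's attention kernel): query i may attend to key j if j <= i and j is either one of the first `num_global` tokens or within the last `window_size` positions. The distance limit, i.e. capping the relative positions, is a separate step and is not shown.

```python
def lambda_shaped_mask(seq_len: int, num_global: int = 4, window_size: int = 6) -> list[list[bool]]:
    """True where query i may attend to key j under a Λ-shaped (sink + local window) mask."""
    return [
        [j <= i and (j < num_global or i - j < window_size) for j in range(seq_len)]
        for i in range(seq_len)
    ]


# "x" marks allowed attention; each row is one query token.
for row in lambda_shaped_mask(12):
    print("".join("x" if allowed else "." for allowed in row))
```

The last row prints `xxxx..xxxxxx`: the 12th token attends to the 4 leading tokens and the 6 most recent ones, the same 10 tokens that the attention sink cache would hold at that step.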

Glaciohound commented 1 year ago

Thanks for mentioning our work from a month ago (https://arxiv.org/abs/2308.16137), "LM-Infinite: Simple On-the-Fly Length Generalization for Large Language Models"! I also noticed the striking similarities between the two methods: (1) we both use a $\Lambda$-shaped attention mask, which is equivalent to "sink tokens" + nearest tokens, and (2) we both re-arrange the distance, which we call a "distance limit" and which they describe in Section 3.2 as "When determining the relative distance and adding positional information to tokens, StreamingLLM focuses on positions within the cache rather than those in the original text".

We are happy to share an implementation here: https://github.com/Glaciohound/LM-Infinite, which you might be interested in checking out.

Somewhat surprisingly, StreamingLLM's implementation feeds tokens one by one even when doing context encoding (such as calculating the perplexity of a sequence), as can be observed here and here. In contrast, our implementation offers a "sequence" mode encoding functionality just like normal language models, which avoids looping through the sequence and provides much better computational efficiency. This is thanks to our specialized attention kernel implementation.

I am also very interested in helping to integrate these papers in HuggingFace Transformers. If you need any further information or help from technical side, please do not hesitate to let me know.

patrickvonplaten commented 1 year ago

Also @gante

tomaarsen commented 1 year ago

@patrickvonplaten I've created a draft PR in #26681 using the Cache that you envisioned. The implementation for the Attention Sink Cache should be fairly simple then.

Also, I ran more experiments over the weekend:

I have a Hugging Face blogpost with more experiments on Attention Sinks coming out soon.

Glaciohound commented 1 year ago

@tomaarsen @patrickvonplaten

This is awesome! In this way, the PR provides a general cache module reusable for other models as well, which is of great help to the whole community and future developers for other models.

What is left to be done is support for the backward pass and for sequence-level forwarding/classification of long sequences, which I am more than happy to help with! The current implementation here is optimized for generation. To also let users run forward and backward passes over long sequences (such as encoding long contexts or classifying long documents, an inevitable need when users do large-scale pre-training or deployment) without token-by-token forwards, our code snippet used in LM-Switch can serve as a starting point for encoding (happy to merge our code!). After that, Sinks/StreamingLLM can continue using the cached features (theoretically compatible) for generation.

tomaarsen commented 1 year ago

I'd love to continue working on this.

ydshieh commented 1 year ago

Of course, @tomaarsen.

You can delete the bot comment (I guess you know it 😄 ) - and welcome to the team!

tomaarsen commented 1 year ago

Thank you! ❤️

woominsong commented 11 months ago

@tomaarsen @Glaciohound Hi! Thanks for all the efforts you have put into making this work. I was wondering if there have been any updates regarding this issue, particularly about forwarding long sequences.

Thanks in advance!