dongxiaolong opened this issue 1 year ago
+1
I think the attention_sinks project is a very promising technique for improving inference on long sequences, without having to train models with ALiBi, etc.
Yeah, it effectively gets rid of max_new_tokens, so generation can continue until the model predicts the EOS token instead of being truncated just because of a large input token count.
Hi @WoosukKwon and @zhuohan123,
I'm going to work on this feature and I'd like to ask about injecting the Attention Sink logic. We need to be able to enable it and set the number of attention sink tokens.
The LLM API allows passing **kwargs (https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py#L74) with extra configuration, which could later be used by every model that supports it. Is that OK from your perspective?
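Something like this hypothetical usage is what I have in mind (note that attention_sink_size is not an existing vLLM argument; it only illustrates forwarding extra configuration through the kwargs):

```python
# Hypothetical usage sketch: `attention_sink_size` is NOT an existing vLLM
# argument; it only illustrates passing extra configuration via **kwargs.
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Mistral-7B-v0.1",
    attention_sink_size=4,  # hypothetical: number of attention-sink tokens to pin
)
params = SamplingParams(temperature=0.8, max_tokens=512)
outputs = llm.generate(["A very long streaming prompt ..."], params)
```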
@sh1ng FYI https://github.com/tomaarsen/attention_sinks has a 3rd-party version which is quite clean and works on a broader set of models than the original paper's github. It's what h2oGPT uses. Just FYI.
Hi @WoosukKwon and @zhuohan123, before starting with my changes, I'd like to consult with you regarding the cache block design for Attention Sink.
To some extent it's similar to Mistral's circular buffer for endless streaming, but there is one significant difference: Attention Sink keeps the first 4 tokens (the left side), and we don't want to move KV cache data at all.
So Mistral does the following (say we have only 4 blocks of 4 tokens each, the token at time t_i is denoted simply by i for conciseness, and x marks an empty slot):
| Blocks -> | Block 0 | Block 1 | Block 2 | Block 3 |
|---|---|---|---|---|
| Initial state | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, x |
| Added token 15 | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
| Added token 16 | 16, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
| Added token 17 | 16, 17, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
Attention can be computed by rotating the blocks, i.e. 16, 17, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 | 16, 17, 2, 3, so we can compute it over consecutive elements, similar to the initial implementation.
Please confirm that the above is correct.
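To double-check my understanding, here is a toy illustration (plain Python, not vLLM code) of that circular-buffer behaviour: new tokens overwrite the oldest slot, and a rotation at the next write position recovers a consecutive logical order.

```python
# Toy circular buffer: 4 blocks x 4 slots flattened into one list; block
# boundaries don't matter for the rotation argument.
cache = list(range(16))                # tokens 0..15, cache completely full

def append(cache, token, num_seen):
    cache[num_seen % len(cache)] = token  # overwrite the oldest slot

append(cache, 16, 16)                  # lands in slot 0, evicting token 0
append(cache, 17, 17)                  # lands in slot 1, evicting token 1
print(cache)                           # [16, 17, 2, 3, 4, ..., 15]

write_pos = 18 % len(cache)            # slot the next token would take
print(cache[write_pos:] + cache[:write_pos])
# [2, 3, ..., 15, 16, 17] -> consecutive logical order again
```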
Attention Sink keeps the first few tokens, as shown in the picture from https://github.com/vllm-project/vllm/issues/1304#issuecomment-1791022511.
So it should be (let's say we only use 2 attention sink tokens):

| Blocks -> | Block 0 | Block 1 | Block 2 | Block 3 |
|---|---|---|---|---|
| Initial state | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, x |
| Added token 15 | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
| Added token 16 | 0, 1, 16, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
| Added token 17 | 0, 1, 16, 17 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
And rotation will not give the correct result, as it would be 0, 1, 16, 17 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 | 0, 1, 16, 17.
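To make the difference concrete, here is the same toy buffer with the first two slots pinned as sink tokens (again only an illustration, not vLLM code): the physical layout matches the table above, but no rotation of the buffer produces the order "sink tokens first, then the recent window in temporal order".

```python
SINK = 2
cache = list(range(16))                     # sink tokens 0 and 1 stay in slots 0 and 1

def append_keep_sink(cache, token, num_seen):
    ring = len(cache) - SINK                # only these slots recycle
    cache[SINK + (num_seen - SINK) % ring] = token

append_keep_sink(cache, 16, 16)             # overwrites slot 2 (token 2)
append_keep_sink(cache, 17, 17)             # overwrites slot 3 (token 3)
print(cache)                                # [0, 1, 16, 17, 4, 5, ..., 15]

# Desired attention order: sinks first, then the window in temporal order.
desired = [0, 1] + list(range(4, 18))       # [0, 1, 4, 5, ..., 15, 16, 17]
rotations = [cache[i:] + cache[:i] for i in range(len(cache))]
print(desired in rotations)                 # False: no plain rotation fixes the order
```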
I see a few possible solutions:
Move the attention sink tokens to the right position:

| Blocks -> | Block 0 | Block 1 | Block 2 | Block 3 |
|---|---|---|---|---|
| Initial state | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, x |
| Added token 15 | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
| Added token 16 | 16, 0, 1, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
| Added token 17 | 16, 17, 0, 1 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |

But that will cause additional memory copying of the size of the attention sink (rough cost estimate below).
Or use 4 `\n` tokens as the attention sink, as shown in the original work.
WDYT?
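For a rough sense of the extra copying in the first option (illustrative numbers only, assuming Mistral-7B-like dimensions and an fp16 KV cache), shifting the sink tokens on every decoding step means copying their K/V vectors in every layer:

```python
# Back-of-the-envelope cost of physically moving the sink tokens each step.
# All numbers are assumptions for illustration, not measurements.
num_layers, num_kv_heads, head_dim = 32, 8, 128  # Mistral-7B-like dimensions
sink_tokens, dtype_bytes = 4, 2                  # 4 sink tokens, fp16 cache

bytes_per_step = 2 * num_layers * num_kv_heads * head_dim * sink_tokens * dtype_bytes
print(f"{bytes_per_step / 1024:.0f} KiB copied per generated token")  # 512 KiB
```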
Hey @sh1ng, I was wondering if you've started implementing attention sinks yet. If you have, would it be possible to open a PR? I'm interested in contributing to this feature.
I tend to use option 3, with a few additional modifications:

- Allocate `BlockSize` system blocks that contain `"\n" * 4`, shifted to all possible rotations. These blocks get prefilled during start-up.
- Prepend the system block(s) and "fuse" them with the first block if necessary, so the attention sink always stays at the right position with minimal modifications to the existing code base and minimal memory movement.

| Blocks -> | Block 0 | Block 1 | Block 2 | Block 3 | Rotated chain of blocks | Comment |
|---|---|---|---|---|---|---|
| Initial state | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, x | NA, NA, \n(1st), \n(2nd), 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 | prepend [NA, NA, \n(1st), \n(2nd)] |
| Added token 15 | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 | \n(2nd), NA, NA, \n(1st), \n(2nd), 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 | prepend and fuse [\n(2nd), NA, NA, \n(1st)] |
| Added token 16 | 16, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 | NA, NA, \n(1st), \n(2nd), 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 | prepend and fuse [\n(1st), \n(2nd), NA, NA] |
| Added token 17 | 16, 17, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 | NA, \n(1st), \n(2nd), 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17 | prepend and fuse [NA, \n(1st), \n(2nd), NA] |

No working code yet.
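Roughly what I have in mind, as an untested sketch (plain Python, not actual vLLM code): prefill `BlockSize` copies of a system block holding the `\n` sink tokens at every possible rotation, then prepend the copy whose layout lines up with the first still-valid slot of Block 0.

```python
# Untested sketch of the "rotated system blocks" idea; not vLLM code.
BLOCK_SIZE = 4
SINK = [r"\n(1st)", r"\n(2nd)"]   # placeholder labels for the newline sink tokens
NA = "NA"                         # unused slot

# Prefilled once at start-up: one system block per possible rotation.
system_blocks = []
for shift in range(BLOCK_SIZE):
    block = [NA] * BLOCK_SIZE
    for i, tok in enumerate(SINK):
        block[(BLOCK_SIZE - len(SINK) + i + shift) % BLOCK_SIZE] = tok
    system_blocks.append(block)

# The block to prepend (and fuse with Block 0) for each window offset,
# matching the "Comment" column of the table above.
for first_valid_slot, block in enumerate(system_blocks):
    print(first_valid_slot, *block)
# 0 NA NA \n(1st) \n(2nd)      -> initial state
# 1 \n(2nd) NA NA \n(1st)      -> after token 15
# 2 \n(1st) \n(2nd) NA NA      -> after token 16
# 3 NA \n(1st) \n(2nd) NA      -> after token 17
```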
@sh1ng Hi, can we keep the first block of the sequence as the attention sink, drop blocks in the middle, and keep the recent blocks? New blocks are allocated for the sequence, but we control the total number of blocks allocated to it. I'm not sure if we need to deal with the position info.
I think attention sink is roughly equivalent to softmax1, as mentioned here: https://www.evanmiller.org/attention-is-off-by-one.html
Not sure if it is possible to apply softmax1 at inference time without retraining the model.
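A quick numerical check of that intuition (not a proof): softmax1 from the linked post adds 1 to the denominator, which is what you get by attending to one extra token whose logit is 0 and whose value vector is all zeros, roughly what a sink token provides.

```python
# Quick numerical check (not a proof) that softmax1 == attending to an extra
# zero-logit, zero-value "sink" position.
import numpy as np

def softmax1(x):
    # softmax1(x)_i = exp(x_i) / (1 + sum_j exp(x_j)), shifted for stability
    e = np.exp(x - x.max())
    return e / (np.exp(-x.max()) + e.sum())

def attention_with_zero_value_sink(scores, values):
    s = np.concatenate([[0.0], scores])                # sink logit fixed at 0
    v = np.vstack([np.zeros_like(values[0]), values])  # sink value is all zeros
    w = np.exp(s - s.max())
    w /= w.sum()
    return w @ v

scores = np.array([2.0, -1.0, 0.5])
values = np.random.randn(3, 4)
print(np.allclose(softmax1(scores) @ values,
                  attention_with_zero_value_sink(scores, values)))  # True
```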
Efficient Streaming Language Models with Attention Sinks (paper). These repos have already implemented it: attention_sinks, streaming-llm.