vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

Could you support Attention Sink? #1304

Open dongxiaolong opened 11 months ago

dongxiaolong commented 11 months ago

The Efficient Streaming Language Models with Attention Sinks paper. These repos have already implemented it: attention_sinks, streaming-llm.

creatorrr commented 10 months ago

+1

I think the attention_sinks project is a very promising technique for improving inference on long sequences without having to train models with ALiBi, etc.

pseudotensor commented 10 months ago

Yeah, effectively getting rid of max_new_tokens, so generation can continue until the model predicts the EOS token instead of getting truncated just because of a large input token count.

sh1ng commented 10 months ago

Hi @WoosukKwon and @zhuohan123,

I'm going to work on this feature, and I'd like to ask about how to inject the Attention Sink logic. We need to be able to enable it and to set the number of attention-sink tokens (see the figure from the paper, arXiv:2309.17453). The `LLM` API allows passing `**kwargs` (https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py#L74) with extra configuration, which can later be used by every model that supports the feature. Is that OK from your perspective?
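For illustration, a minimal sketch of what this could look like from the user's side; `enable_attention_sink` and `attention_sink_size` are hypothetical names, not an existing vLLM option:

```python
from vllm import LLM, SamplingParams

# Hypothetical kwargs (not part of vLLM today), forwarded through LLM(**kwargs)
# into the engine/model config as proposed above.
llm = LLM(
    model="mistralai/Mistral-7B-v0.1",
    enable_attention_sink=True,  # turn the feature on
    attention_sink_size=4,       # number of sink tokens pinned at the start of the cache
)

params = SamplingParams(max_tokens=4096)
outputs = llm.generate(["A very long prompt ..."], params)
```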

pseudotensor commented 10 months ago

@sh1ng FYI, https://github.com/tomaarsen/attention_sinks is a 3rd-party version that is quite clean and works on a broader set of models than the original paper's GitHub repo. It's what h2oGPT uses.

sh1ng commented 10 months ago

Hi @WoosukKwon and @zhuohan123, before starting with my changes, I'd like to consult with you regarding the cache block design for Attention Sink.

To some extent, it's similar to Mistral's circular buffer for endless streaming. But there is one significant difference: Attention Sink keeps the first 4 tokens (the left side), and we don't want to move KV-cache data at all.

So Mistral does the following (say we have only 4 blocks of 4 tokens each; the token at time t_i is denoted by i for conciseness, and x is an empty slot):

| | Block 0 | Block 1 | Block 2 | Block 3 |
| --- | --- | --- | --- | --- |
| Initial state | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, x |
| Added token 15 | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
| Added token 16 | 16, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
| Added token 17 | 16, 17, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |

Attention can then be computed by rotating the blocks, i.e. 16, 17, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 | 16, 17, 2, 3, so we can compute it over consecutive elements, similar to the initial implementation.

Please confirm that the above is correct.
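For concreteness, a small sketch (not vLLM code) of the circular-slot mapping described above, using the same 4 blocks of 4 tokens:

```python
BLOCK_SIZE = 4
NUM_BLOCKS = 4
WINDOW = BLOCK_SIZE * NUM_BLOCKS  # 16 slots in total


def slot_for(token_pos: int) -> int:
    """Token at position `token_pos` overwrites slot token_pos % WINDOW."""
    return token_pos % WINDOW


def rotated_order(num_tokens: int) -> list[int]:
    """Slots listed in chronological order of the tokens they hold,
    i.e. the rotation that makes the window contiguous again."""
    start = max(0, num_tokens - WINDOW)
    return [slot_for(p) for p in range(start, num_tokens)]


# After adding token 17 (18 tokens seen so far), slots 0 and 1 hold tokens
# 16 and 17, and the chronological order starts at slot 2 -- matching the
# "16, 17, 2, 3 | ... | 12, 13, 14, 15 | 16, 17, 2, 3" rotation above.
assert slot_for(16) == 0 and slot_for(17) == 1
assert rotated_order(18)[:4] == [2, 3, 4, 5]
```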

Attention Sink, on the other hand, keeps a few of the initial tokens, per the picture from https://github.com/vllm-project/vllm/issues/1304#issuecomment-1791022511.

So it should be (say we only use 2 attention-sink tokens):

| | Block 0 | Block 1 | Block 2 | Block 3 |
| --- | --- | --- | --- | --- |
| Initial state | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, x |
| Added token 15 | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
| Added token 16 | 0, 1, 16, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
| Added token 17 | 0, 1, 16, 17 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |

And rotation will not give the correct result, since it would yield 0, 1, 16, 17 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 | 0, 1, 16, 17.
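A matching sketch of the slot mapping with 2 pinned sink tokens, just to spell out why a plain rotation no longer works:

```python
BLOCK_SIZE = 4
NUM_BLOCKS = 4
WINDOW = BLOCK_SIZE * NUM_BLOCKS
SINK = 2  # attention-sink tokens that are never evicted


def slot_for(token_pos: int) -> int:
    if token_pos < SINK:
        return token_pos  # sink tokens keep their original slots
    rolling = WINDOW - SINK  # only the remaining 14 slots are recycled
    return SINK + (token_pos - SINK) % rolling


# Tokens 16 and 17 land in slots 2 and 3 while tokens 0 and 1 stay put,
# so Block 0 becomes [0, 1, 16, 17]: the oldest and newest tokens now sit
# in the same block, and no single rotation restores chronological order.
assert slot_for(16) == 2 and slot_for(17) == 3
```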

I see a few possible solutions:

1. Modify paged_attention_kernel to handle the above case. I'm not sure yet that it is feasible and that it won't affect performance.

2. Move the attention-sink tokens to the right position:

   | | Block 0 | Block 1 | Block 2 | Block 3 |
   | --- | --- | --- | --- | --- |
   | Initial state | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, x |
   | Added token 15 | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
   | Added token 16 | 16, 0, 1, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |
   | Added token 17 | 16, 17, 0, 1 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 |

   But that will cause additional memory copying of the size of the attention sink.

3. Keep the attention-sink KV in a separate block while maintaining a circular buffer of streamed tokens, and compute the attention of every new token against that block plus the buffer. We could even try to reuse the block for all requests by filling it with `\n` * 4, as shown in the original work (figure from arXiv:2309.17453). A rough sketch of this layout follows below.
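As referenced in option 3, here is a rough sketch of its data flow, assuming illustrative shapes and a single query head; this is not the paged-attention kernel, just the concatenation idea:

```python
import torch

HEAD_DIM, SINK, WINDOW = 64, 4, 16

# The sink KV lives in its own small block and is never evicted; the
# streamed tokens live in a (simplified, already-ordered) rolling buffer.
sink_k = torch.randn(SINK, HEAD_DIM)
sink_v = torch.randn(SINK, HEAD_DIM)
win_k = torch.randn(WINDOW, HEAD_DIM)
win_v = torch.randn(WINDOW, HEAD_DIM)


def attend(q: torch.Tensor, num_recent: int) -> torch.Tensor:
    """Single-query attention over [sink block] + [valid window slots]."""
    k = torch.cat([sink_k, win_k[:num_recent]])
    v = torch.cat([sink_v, win_v[:num_recent]])
    scores = (k @ q) / HEAD_DIM ** 0.5
    return torch.softmax(scores, dim=0) @ v


out = attend(torch.randn(HEAD_DIM), num_recent=WINDOW)
print(out.shape)  # torch.Size([64])
```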

WDYT?

felixzhu555 commented 6 months ago

Hey @sh1ng, I was wondering if you've started implementing attention sinks yet. If you have, would it be possible to open a PR? I'm interested in contributing to this feature.

sh1ng commented 6 months ago

I tend to use option 3, with a few additional modifications:

  • Allocate `block_size` system blocks that contain `\n` * 4 shifted to all possible rotations.
  • These blocks get prefilled during start-up.
  • Prepend the system block(s) and "fuse" them with the first block if necessary, to always keep the attention sink at the right position with minimal modifications to the existing code base and minimal memory movement.

No working code yet.

ChuanhongLi commented 5 months ago

> I tend to use option 3, with a few additional modifications:
>
> • Allocate `block_size` system blocks that contain `\n` * 4 shifted to all possible rotations.
> • These blocks get prefilled during start-up.
> • Prepend the system block(s) and "fuse" them with the first block if necessary, to always keep the attention sink at the right position with minimal modifications to the existing code base and minimal memory movement.
>
> | | Block 0 | Block 1 | Block 2 | Block 3 | Rotated chain of blocks | Comment |
> | --- | --- | --- | --- | --- | --- | --- |
> | Initial state | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, x | NA, NA, \n(1st), \n(2nd), 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 | prepend [NA, NA, \n(1st), \n(2nd)] |
> | Added token 15 | 0, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 | \n(2nd), NA, NA, \n(1st), \n(2nd), 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 | prepend and fuse [\n(2nd), NA, NA, \n(1st)] |
> | Added token 16 | 16, 1, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 | NA, NA, \n(1st), \n(2nd), 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 | prepend and fuse [\n(1st), \n(2nd), NA, NA] |
> | Added token 17 | 16, 17, 2, 3 | 4, 5, 6, 7 | 8, 9, 10, 11 | 12, 13, 14, 15 | NA, \n(1st), \n(2nd), 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17 | prepend and fuse [NA, \n(1st), \n(2nd), NA] |
>
> No working code yet.

@sh1ng Hi, can we keep the first block of the sequence as the attention sink, and drop blocks in the middle while keeping the recent blocks? New blocks are still allocated for the sequence, but we control the total number of blocks allocated to it. Not sure if we need to deal with the position info.
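For what it's worth, a minimal sketch of that block-level policy (not vLLM's scheduler code, and it sidesteps the position-info question):

```python
def blocks_to_keep(block_table: list[int], max_blocks: int) -> list[int]:
    """Keep the first block (the sink) plus the most recent blocks, and
    drop everything in the middle once the per-sequence budget is hit."""
    if len(block_table) <= max_blocks:
        return block_table
    return block_table[:1] + block_table[-(max_blocks - 1):]


# With a budget of 4 blocks, the three middle blocks get dropped:
assert blocks_to_keep([10, 11, 12, 13, 14, 15, 16], 4) == [10, 14, 15, 16]
```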

Atry commented 4 months ago

I think attention sink is roughly equivalent to softmax1, as mentioned here: https://www.evanmiller.org/attention-is-off-by-one.html

Not sure if it is possible to apply softmax1 at inference time without re-training the model.
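For reference, the only change softmax1 makes is an extra 1 in the denominator, which lets a head put near-zero total weight on its values; a tiny numeric sketch (no numerical-stability tricks):

```python
import numpy as np


def softmax(x):
    e = np.exp(x)
    return e / e.sum()


def softmax1(x):
    # softmax1(x)_i = exp(x_i) / (1 + sum_j exp(x_j))
    e = np.exp(x)
    return e / (1.0 + e.sum())


scores = np.array([-4.0, -5.0, -6.0])  # a head that "wants" to attend to nothing
print(softmax(scores).sum())   # 1.0 -- forced to spread weight somewhere
print(softmax1(scores).sum())  # ~0.03 -- can effectively opt out
```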