mit-han-lab / streaming-llm

[ICLR 2024] Efficient Streaming Language Models with Attention Sinks
https://arxiv.org/abs/2309.17453
MIT License

Is past_key_values repeatedly computed? #46

Open freyamom opened 11 months ago

freyamom commented 11 months ago

Hi! Attention sinks are amazing for LLMs. I am confused about past_key_values in streaming-llm. My mental model was that past_key_values would be recomputed for every new input, but I noticed that streaming-llm stores past_key_values by turning use_cache on. I tried my best to store past_key_values and reuse them when running inference on new input, but the output came out very strange, while streaming-llm's output is really good. I would really like to know what you did to make reusing past_key_values work. Thanks a lot!

ysanimals commented 11 months ago

By the way, I have a question about how to use window attention with re-computation: what exactly needs to be re-computed? After reading the code, I also found that past_key_values are stored in streaming-llm. What is the difference between re-computation and this code, and where can I find the code for the re-computation?
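
For context, the "window attention with re-computation" baseline in the paper rebuilds the KV states from only the most recent tokens for every new token, instead of carrying a cache forward. Below is a minimal, hypothetical sketch of that baseline; the `window` size and the function name are illustrative and not from this repository:

```python
# Sketch of sliding-window attention WITH re-computation: no cache is reused,
# so every step re-encodes the last `window` tokens from scratch.
import torch

@torch.no_grad()
def recompute_window_step(model, token_ids, window=1024):
    """Predict the next token by re-encoding only the last `window` tokens."""
    ctx = token_ids[:, -window:]                 # keep only the recent window
    out = model(input_ids=ctx, use_cache=False)  # full forward pass, no KV reuse
    return out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
```

Calling this once per generated token gives strong perplexity but quadratic attention work per step inside the window, which is the cost streaming-llm avoids by keeping (and evicting from) the KV cache instead of re-computing it.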

Guangxuan-Xiao commented 11 months ago

Hi!

Thank you for reaching out and expressing your interest in streaming-llm's attention sink feature. You're right: when use_cache=True is set, the model reuses past_key_values to make subsequent inferences efficient.

If you're encountering unexpected outputs when storing and reusing past_key_values on your own, there may be a discrepancy in how you handle the cache. To make sure you're using it correctly, please refer to the examples provided in our repository:

  1. run_streaming_llama.py
  2. eval_long_ppl.py

Guangxuan
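
For anyone hitting the same issue, here is a minimal, hypothetical sketch (not the repository's exact code) of the standard Hugging Face pattern for reusing past_key_values with use_cache=True: after the prompt is processed once, each step feeds only the newly sampled token and passes the cache back in, so earlier tokens are not recomputed. The checkpoint name and generation length are placeholders.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

input_ids = tokenizer("Hello, streaming world!", return_tensors="pt").input_ids

past_key_values = None
next_input = input_ids          # first step: the whole prompt
generated = input_ids

with torch.no_grad():
    for _ in range(20):         # greedily decode 20 tokens
        outputs = model(
            input_ids=next_input,
            past_key_values=past_key_values,
            use_cache=True,     # ask the model to return the updated KV cache
        )
        past_key_values = outputs.past_key_values   # carry the cache forward
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        next_input = next_token                     # feed ONLY the new token next step
        generated = torch.cat([generated, next_token], dim=-1)

print(tokenizer.decode(generated[0]))
```

A common source of "very strange" outputs is feeding the full sequence again together with the cache, or evicting cache entries without adjusting positions; streaming-llm assigns positions within the cache rather than in the original text, which plain reuse like the sketch above does not do.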

freyamom commented 11 months ago

@Guangxuan-Xiao Thanks for your reply. I have another question about the relationship between input_length and the KV cache size. Example: input_length = 100, KV cache size = 30. Does streaming-llm inference do:

  1. model(input[:-30])
  2. model(input[:30]), keep the KV cache, then model(input[30:60]), keep the KV cache, then model(input[60:])

Which one is correct?

freyamom commented 11 months ago

Sorry, option 2 should also add the attention sink for the first 4 tokens :)

freyamom commented 11 months ago

@Guangxuan-Xiao Sorry to bother you again. I have another question about the relationship between input_length and the KV cache size. Example: input_length = 100, KV cache size = 30. Does streaming-llm inference do:

  1. model(input[:-30])
  2. model(input[:30]), keep the KV cache, then model(input[30:60]), keep the KV cache, then model(input[60:]), with the attention sink added for the first 4 tokens

Which one is correct?
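
In case it helps to make option 2 concrete, here is a hypothetical sketch of that scheme, assuming the legacy tuple format for past_key_values (per-layer (key, value) tensors shaped [batch, heads, seq_len, head_dim]). The names evict_middle, chunked_prefill, start_size, and recent_size are illustrative, not the repository's API; the actual eviction logic lives in kv_cache.py, and this is not necessarily how run_streaming_llama.py feeds the prompt.

```python
import torch

def evict_middle(past_key_values, start_size=4, recent_size=30):
    """Keep only the first `start_size` sink tokens and the last `recent_size` tokens."""
    new_past = []
    for k, v in past_key_values:
        seq_len = k.size(2)  # assumed layout: [batch, heads, seq_len, head_dim]
        if seq_len <= start_size + recent_size:
            new_past.append((k, v))
            continue
        k = torch.cat([k[:, :, :start_size], k[:, :, -recent_size:]], dim=2)
        v = torch.cat([v[:, :, :start_size], v[:, :, -recent_size:]], dim=2)
        new_past.append((k, v))
    return tuple(new_past)

@torch.no_grad()
def chunked_prefill(model, input_ids, chunk_size=30):
    """Feed the prompt in chunks, carrying and trimming the KV cache between chunks."""
    past = None
    logits = None
    for start in range(0, input_ids.size(1), chunk_size):
        chunk = input_ids[:, start:start + chunk_size]
        out = model(input_ids=chunk, past_key_values=past, use_cache=True)
        logits = out.logits
        past = evict_middle(out.past_key_values)  # keep sinks + recent window only
    return logits, past
```

Note that slicing the cache like this is only part of the story: with unmodified positions, quality still degrades once the absolute positions grow past what the model was trained on, which is why streaming-llm assigns positions within the cache rather than in the original text.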