mit-han-lab / streaming-llm

[ICLR 2024] Efficient Streaming Language Models with Attention Sinks
https://arxiv.org/abs/2309.17453
MIT License

Comparison with SWA in Mistral #24

Open · casper-hansen opened this issue 9 months ago

casper-hansen commented 9 months ago

Hi @Guangxuan-Xiao, do you have any comparison with the sliding window attention (SWA) from Mistral? The paper only describes SWA with re-computation, which is not how it works in the new models.

Sliding Window with Re-computation rebuilds the KV states from the L recent tokens for each new token.

This is not what the Mistral model does. It does not rebuild the KV states; instead, it evicts the oldest entries of the cache in favor of the newest ones.
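To make the distinction concrete, here is a minimal Python sketch, assuming a hypothetical `RollingKVCache` class and `recompute_window` helper (neither is taken from the Mistral or StreamingLLM code). Both policies attend to the same most recent tokens; the difference is cost: the rolling cache computes each token's key/value once and drops the oldest entry, whereas re-computation re-runs the model over the whole window for every new token.

```python
from collections import deque


class RollingKVCache:
    """Hypothetical rolling-buffer KV cache: holds at most `window` entries
    and evicts the oldest one when a new entry is appended."""

    def __init__(self, window: int):
        self.window = window
        # deque(maxlen=...) discards items from the left once it is full
        self.entries = deque(maxlen=window)

    def append(self, position: int, kv) -> None:
        # Store the newest token's (absolute position, key/value) pair
        self.entries.append((position, kv))

    def positions(self) -> list[int]:
        return [pos for pos, _ in self.entries]


def recompute_window(tokens: list[int], window: int) -> list[int]:
    """Sliding window *with re-computation*: the KV states for each new token
    are rebuilt from scratch from the last `window` tokens."""
    return tokens[-window:]


cache = RollingKVCache(window=4)
for pos in range(10):
    cache.append(pos, kv=f"kv{pos}")  # placeholder for real key/value tensors

print(cache.positions())                      # [6, 7, 8, 9]
print(recompute_window(list(range(10)), 4))   # [6, 7, 8, 9]
```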

Guangxuan-Xiao commented 9 months ago

Hi, please check my explanation at https://github.com/mit-han-lab/streaming-llm/issues/33#issuecomment-1758597666, and let me know if you have any further questions!

verlocks commented 9 months ago

Hi @Guangxuan-Xiao, thanks for your explanation! However, it doesn't seem to cover the SWA used in the Mistral model. Mistral uses Sliding Window Attention at inference time, and I believe it doesn't recompute anything during inference. I am wondering how it manages this, because in your paper the model's performance degrades when using window attention.

My current guess is that the Mistral model was trained with Sliding Window Attention and, as a result, avoided the attention sink phenomenon. (This was asked in one of their issues but has not been answered yet.)

tomaarsen commented 9 months ago

For reference, the Mistral model degrades in performance over time just like dense attention methods:

[Figure: perplexity as input length grows, for attention_sinks, transformers and windowed]

Here, attention_sinks refers to the StreamingLLM approach, transformers is their model used via the transformers library, and windowed is simple window attention with position ID shifting.
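The two cache policies compared in the figure can be sketched as follows. This is a minimal illustration, assuming a StreamingLLM-style policy that keeps a few initial tokens (the attention sinks) plus the most recent tokens; the function names and sizes are hypothetical and are not the attention_sinks package's API. "Position ID shifting" is modeled by numbering positions by slot in the cache rather than by the token's index in the original text.

```python
def streaming_cache_indices(seq_len: int, n_sink: int, window: int) -> list[int]:
    """Absolute token indices kept in the KV cache under a StreamingLLM-style
    policy: the first `n_sink` tokens plus the most recent `window` tokens."""
    if seq_len <= n_sink + window:
        return list(range(seq_len))
    sinks = list(range(n_sink))
    recent = list(range(seq_len - window, seq_len))
    return sinks + recent


def window_cache_indices(seq_len: int, window: int) -> list[int]:
    """Plain window attention: only the most recent `window` tokens survive,
    so the initial tokens are evicted once seq_len exceeds the window."""
    return list(range(max(0, seq_len - window), seq_len))


def shifted_position_ids(kept: list[int]) -> list[int]:
    """Position ID shifting: positions follow the slot in the cache
    (0, 1, 2, ...) rather than the token's index in the original text."""
    return list(range(len(kept)))


kept = streaming_cache_indices(seq_len=20, n_sink=4, window=6)
print(kept)                          # [0, 1, 2, 3, 14, 15, 16, 17, 18, 19]
print(shifted_position_ids(kept))    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(window_cache_indices(20, 10))  # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
```

Both caches hold the same number of entries here; the only difference is whether the first few tokens are retained.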

Furthermore, when giving it subsequent prompts (160 prompts in a row):

[Figure: results over 160 consecutive prompts]

Note: The automatic detection of fluency losses is very naive: it tries to count the number of real words in the response, but that can result in false positives if, e.g., the prompt is to generate some German text. See demo/streaming_logs for the full logs to get a better picture of the real generative performance.

For example, compare the Mistral logs for transformers and attention_sinks: there is a big difference after roughly 250 lines.
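The naive fluency check described in the note can be approximated like this. It is a hypothetical sketch, not the code in demo/streaming_logs: it assumes a small hard-coded vocabulary (a real check would load a full English word list) and a made-up 0.5 threshold. It also reproduces the false positive mentioned above, since German output scores as non-fluent.

```python
import re

# Hypothetical vocabulary; a real check would use a full English word list.
KNOWN_WORDS = {"the", "model", "is", "a", "of", "and", "to", "in", "text", "it"}


def real_word_ratio(response: str, vocabulary: set[str]) -> float:
    """Fraction of alphabetic tokens in `response` that appear in `vocabulary`."""
    words = re.findall(r"[a-zA-Z']+", response.lower())
    if not words:
        return 0.0
    return sum(word in vocabulary for word in words) / len(words)


def looks_fluent(response: str, vocabulary: set[str], threshold: float = 0.5) -> bool:
    # Naive: valid non-English text (e.g. German) is also flagged as a fluency loss.
    return real_word_ratio(response, vocabulary) >= threshold


print(real_word_ratio("The model is in the text", KNOWN_WORDS))  # 1.0
print(looks_fluent("Das Modell ist gut", KNOWN_WORDS))           # False
```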

hmzo commented 9 months ago

In my opinion, the "sliding window attention" mentioned in Mistral is equivalent to the "window attention" mentioned in attention_sinks.

casper-hansen commented 9 months ago

@tomaarsen I see your point here. My point was more about the latency reported in the paper.

A more interesting comparison would be vLLM/TGI with and without attention sinks, since nobody uses raw Hugging Face generate methods in production.

I wish the paper's authors had compared against sliding window attention as it is actually used, because in practice it has none of the recomputation overhead presented in the paper.
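The latency gap can be illustrated with a rough cost model. This is a simplified sketch under stated assumptions: it counts only query-key dot products per layer and head for one generated token and ignores MLPs, projections and memory traffic. Re-computation re-encodes the whole window causally, which is quadratic in the window size, while a cached rolling window only needs the new token's query against the cached keys, which is linear.

```python
def attention_scores_per_token(window: int, recompute: bool) -> int:
    """Rough count of query-key dot products (per layer, per head) needed to
    generate one new token with a context window of `window` tokens."""
    if recompute:
        # Re-encode all `window` tokens causally: token i attends to i tokens.
        return window * (window + 1) // 2
    # Cached KV: only the new token's query attends to the `window` cached keys.
    return window


for w in (1024, 4096):
    print(w, attention_scores_per_token(w, True), attention_scores_per_token(w, False))
# 1024 524800 1024
# 4096 8390656 4096
```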

dengxiaotian123 commented 7 months ago

Hello @verlocks, I have a question. In Mistral's 'one_file_ref.py' script, it seems that sliding_window is used during training but not during inference (because input_ids.shape[-1] should be 1 during inference). Is that understanding correct?
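One way to think about this question: when input_ids.shape[-1] is 1, a single query needs no mask, but a window can still be enforced by the size of the KV cache. The sketch below is an assumption for illustration, not a transcription of 'one_file_ref.py'. It uses the convention that a query attends to the `window` most recent keys including itself (the reference code's exact off-by-one convention may differ), and shows that a prefill mask row and a single-token decode over a rolling cache see the same keys.

```python
def sliding_window_mask(seq_len: int, window: int) -> list[list[bool]]:
    """Causal sliding-window mask for a multi-token prefill: query i may
    attend to key j iff i - window < j <= i (assumed convention)."""
    return [[i - window < j <= i for j in range(seq_len)] for i in range(seq_len)]


def decode_visible_keys(position: int, window: int) -> list[int]:
    """Single-token decoding (input length 1): no mask is built; the window is
    enforced by a rolling KV cache that only holds the last `window` keys."""
    return list(range(max(0, position - window + 1), position + 1))


row = sliding_window_mask(seq_len=11, window=4)[10]
print([j for j, allowed in enumerate(row) if allowed])  # [7, 8, 9, 10]
print(decode_visible_keys(position=10, window=4))       # [7, 8, 9, 10]
```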

ehuaa commented 4 months ago

Hi @tomaarsen, it's a bit odd: the official transformers API docs (https://huggingface.co/docs/transformers/en/model_doc/mistral) give Mistral a maximum input length of almost 128k, saying that Mistral's sliding window attention allows sequences of up to 4096*32 tokens. But in your test, the model failed once the input length grew to 8k. Is this right?
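The 4096*32 figure is simple arithmetic: a window of 4096 tokens times 32 stacked layers (the layer count of Mistral-7B). The sketch below, with a hypothetical helper name, computes this theoretical reach. It is only an upper bound on how far information can propagate through the layers, not a guarantee that the model still performs well at that length.

```python
def theoretical_attention_span(window: int, n_layers: int) -> int:
    """Upper bound on how far back information can propagate through stacked
    sliding-window layers: each layer can move information up to `window`
    positions, so n_layers layers reach roughly n_layers * window tokens."""
    return window * n_layers


span = theoretical_attention_span(window=4096, n_layers=32)
print(span)  # 131072
```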

tomaarsen commented 4 months ago

when the input length grows to 8k, it failed. Is this right?

That's right. Although the model doesn't crash until 128k, it doesn't perform well once it has exceeded the pretraining size of 8k tokens.

ehuaa commented 4 months ago

Thanks for your quick reply. So, for industrial use, inputs that exceed the pretraining size of 8k will not work with the Mistral model.

tomaarsen commented 4 months ago

Correct, not for mistralai/Mistral-7B-v0.1, at least. There are some Mistral-based models that work on longer sequence lengths, e.g.: https://huggingface.co/NousResearch/Yarn-Mistral-7b-128k

ehuaa commented 4 months ago

Thanks Tom, I'll check the URL later!

ehuaa commented 4 months ago

Hi @tomaarsen, I have another question. In your test above, with sliding_window set to 4096 in the Mistral config, the model still has reasonable perplexity when the input length grows to 8k. But the attention sink paper says "Window attention collapses once the input length exceeds the cache size, i.e., the initial tokens are evicted." With Mistral, the model doesn't suddenly fail when the input length exceeds 4096. Was the Mistral model fine-tuned in some new way for sliding window attention? Can you help me figure this out? Thanks!