huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

A question about the implementation of SinkCache #31581

Closed: kisseternity closed this issue 3 days ago

kisseternity commented 3 days ago

System Info

Ubuntu

Who can help?

No response


Reproduction

Review the code in cache_utils.py for the implementation of SinkCache.

Expected behavior

In the original paper, it is indicated that the position ids related to the key matrix should be recomputed so that they refer to positions within the cache rather than positions in the original text.

[image: formula from the paper]
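To make the question concrete, here is a toy sketch of what I understand the paper to mean (the helper below is hypothetical, not code from the library):

```python
import torch

def cache_relative_positions(num_seen_tokens: int, window_length: int) -> torch.Tensor:
    # Hypothetical helper: positions the cached keys should occupy.
    # Once the cache is full, keys are treated as if they sat at positions
    # 0..window_length-1, no matter how many tokens were actually seen.
    length = min(num_seen_tokens, window_length)
    return torch.arange(length)
```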

However, in the implementation code, it seems that the last key_states value uses the original position id instead of a recomputed one, as the following code shows:

[image: screenshot of the SinkCache update code]

Please take a look and check whether I am right, thanks.

zucchini-nlp commented 3 days ago

Hey!

Yes, SinkCache should use position ids inside the cache. For that we first re-rotate the previous key/values that will remain in the cache here. In other words, if the cache size is 10 and it's full, we have to discard the oldest token (except the sink tokens) and add a new one at the end. That causes a shift by one for all previous tokens except the sink tokens.
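Roughly, the re-rotation amounts to applying RoPE with a negative shift. A minimal sketch, assuming a shift of exactly one position (this is not the actual cache_utils.py code):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: negate and swap the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rerotate_keys_back_one(keys: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # cos/sin are the rotary embeddings for a single-position shift.
    # Rotating by -theta uses cos(theta) and -sin(theta), which moves every
    # cached key one position towards the start of the cache.
    return keys * cos - rotate_half(keys) * sin
```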

The new key/value pair that we are appending at the end already has the correct position applied to it. If we take the Llama model as an example, RoPE is applied inside the attention module just before the key/values are sent to the cache.
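For reference, the ordering looks roughly like this (a simplified sketch of a Llama-style attention forward; the real code passes more arguments and runs the actual attention afterwards):

```python
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

def rope_then_cache(query_states, key_states, value_states,
                    cos, sin, past_key_value, layer_idx, cache_kwargs):
    # RoPE first: queries and keys are rotated using the prepared position ids
    # (encoded in cos/sin) before anything touches the cache.
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    # Only then do the already-rotated keys/values enter the cache.
    key_states, value_states = past_key_value.update(key_states, value_states,
                                                     layer_idx, cache_kwargs)
    return query_states, key_states, value_states
```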

The position ids we use in RoPE are prepared here: when we use generate() we always prepare the inputs first by creating position ids from scratch (unless position ids were passed by the user into generate(), but that's a known limitation which will be fixed soon). So we can be sure that the position ids will never exceed the cache's maximum length.
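Schematically, the preparation step does something along these lines (a sketch under the assumption that no position ids were passed in; the real logic also handles attention masks and other cases):

```python
import torch

def make_position_ids(past_key_values, input_ids):
    # New position ids start from the cache's reported length. For SinkCache,
    # get_seq_length() never exceeds the window length, so the resulting
    # position ids stay bounded by the cache's maximum size.
    past_length = past_key_values.get_seq_length()
    return torch.arange(past_length, past_length + input_ids.shape[1]).unsqueeze(0)
```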

Let me know if that makes sense. Also, please note that this kind of question is better placed in the forum :)

kisseternity commented 3 days ago

Thanks for your detailed explanation. I had overlooked the code in the prepare_inputs_for_generation function.