huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

[WIP] Sink cache: fix implementation to shift key states #33692

Open why-in-Shanghaitech opened 3 days ago

why-in-Shanghaitech commented 3 days ago

What does this PR do?

Fixes #33691

This PR shifts the new key states along with those already in the cache. It could also solve some issues mentioned in #31537 and #32315, where the initial input key states are longer than the context window or shorter than the attention sink.

I hope it could better explain the issue in #33691.

It does not solve the issue that the new key states have large positional encodings (see the last few lines of #33691), so it still does not work with model.generate. Once that is fixed, the model will be able to generate tokens properly.
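For context, a minimal sketch of the key-state shift this PR relies on (illustrative only, not the PR's exact code; it assumes standard RoPE with cos/sin tables of shape (max_positions, head_dim)):

```python
import torch

def rotate_half(x):
    # Standard RoPE helper: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def shift_keys_back(key_states, cos, sin, shift):
    # Rotate already-RoPE-encoded keys back by `shift` positions:
    # R((m - shift) * theta) = R(-shift * theta) @ R(m * theta),
    # and a rotation by a negative angle uses cos(shift * theta) and -sin(shift * theta).
    shift_cos = cos[shift]  # (head_dim,), broadcasts over (batch, heads, seq, head_dim)
    shift_sin = sin[shift]
    return key_states * shift_cos - rotate_half(key_states) * shift_sin
```

The same identity has to be applied both to the keys already in the cache and to the new key states once they fall out of the window, which is what this PR changes.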

Also, it would be better if we could come up with some good tests, e.g. redesign the key states in the test so that they carry their actual positional encodings as attributes?

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@ArthurZucker @zucchini-nlp @gante I didn't finish reading all the previous discussions (they are too long)... but I know you should be very familiar with this.

why-in-Shanghaitech commented 3 days ago

If we make some ad-hoc changes: https://github.com/huggingface/transformers/blob/52daf4ec768fb9ffe84a0c373834172a7c54aecc/src/transformers/models/llama/modeling_llama.py#L962 https://github.com/huggingface/transformers/blob/52daf4ec768fb9ffe84a0c373834172a7c54aecc/src/transformers/models/llama/modeling_llama.py#L967

Make these two lines always evaluate to true. Then it outputs almost the same text as the original StreamingLLM implementation.

It might be related to the changes in #32315, but that PR raises errors on my end.

why-in-Shanghaitech commented 2 days ago

https://github.com/huggingface/transformers/pull/33692/commits/45c70a70c1588ae839a097088d8b36e4a688d859 This commit tries to solve the problem above. I am not sure whether it will affect code in other places...

why-in-Shanghaitech commented 2 days ago

@why-in-Shanghaitech yes, SinkCache is currently very heavily bugged and the tests you mentioned are failing. We know when it started breaking: when we modified the attention mask preparation stages. Previously the re-rotations worked because the attention mask (and consequently the position ids) never went beyond the window_length.

My current idea was to bring back the cropping, as keeping cos/sin up to the maximum possible sequence length seemed a bit inefficient in terms of memory. But I see it is also a possible solution. Since I paused work on SinkCache for a few weeks to focus on other tasks, I am tagging @gante.

@gante will say what is better for SinkCache; he is the point of contact for anything in generation.

Thank you for the reply!

Yes... I think cropping is necessary, since the attention mask grows without looking at the past key values.

https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/generation/utils.py#L645-L649

Maybe we can crop here instead? The past key values are known in model_kwargs. Or maybe we should process the attention mask in the Cache classes?
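To make the suggestion concrete, a rough sketch of what cropping at that point could look like (get_max_cache_len is a placeholder name, not the actual transformers API; the concatenation roughly mirrors how the mask is extended each step at the linked lines):

```python
import torch

def update_attention_mask(model_kwargs):
    # Extend the mask by one position for the newly generated token,
    # then crop it so it never exceeds what the (sink) cache actually holds.
    attention_mask = model_kwargs["attention_mask"]
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )

    past_key_values = model_kwargs.get("past_key_values")
    # Placeholder for however the cache exposes its maximum length.
    max_cache_len = getattr(past_key_values, "get_max_cache_len", lambda: None)()
    if max_cache_len is not None and attention_mask.shape[-1] > max_cache_len:
        attention_mask = attention_mask[:, -max_cache_len:]

    model_kwargs["attention_mask"] = attention_mask
    return model_kwargs
```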

Also, since we cannot modify the positional encoding of the queries, it seems impossible to keep the cos/sin within the window length. And I don't think we need to: this PR only caches cos/sin up to the window length plus the longest prompt length (not the entire generated sequence length).

https://github.com/why-in-Shanghaitech/transformers/blob/1550f8f126946570c748131552c6b97fd33ba874/src/transformers/cache_utils.py#L969-L971
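As a rough illustration of that bound (assumed names, not the PR's exact code):

```python
import torch

class CosSinCache:
    # Cache cos/sin rows only up to window_length + the longest prompt length
    # observed so far, instead of growing with the full generated length.
    def __init__(self, window_length: int):
        self.window_length = window_length
        self.cos = None
        self.sin = None

    def update(self, cos: torch.Tensor, sin: torch.Tensor, prompt_length: int):
        needed = min(self.window_length + prompt_length, cos.shape[0])
        if self.cos is None or self.cos.shape[0] < needed:
            self.cos = cos[:needed].clone()
            self.sin = sin[:needed].clone()
        return self.cos, self.sin
```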