why-in-Shanghaitech opened 3 days ago
If we make an ad-hoc change at these two lines: https://github.com/huggingface/transformers/blob/52daf4ec768fb9ffe84a0c373834172a7c54aecc/src/transformers/models/llama/modeling_llama.py#L962 https://github.com/huggingface/transformers/blob/52daf4ec768fb9ffe84a0c373834172a7c54aecc/src/transformers/models/llama/modeling_llama.py#L967
so that both conditions are always true, the model outputs almost the same text as the original StreamingLLM implementation.
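For readers unfamiliar with StreamingLLM/SinkCache: the cache keeps the first few "attention sink" tokens plus a sliding window of the most recent ones. A minimal sketch of that eviction policy (the helper name is hypothetical, not the Hugging Face API):

```python
def sink_cache_keep_indices(seq_len, num_sink_tokens, window_length):
    """Return the token indices a sink cache retains: the first
    `num_sink_tokens` ("attention sinks") plus the most recent tokens,
    capped at `window_length` entries total."""
    if seq_len <= window_length:
        return list(range(seq_len))
    num_recent = window_length - num_sink_tokens
    return list(range(num_sink_tokens)) + list(range(seq_len - num_recent, seq_len))
```

For example, with 2 sinks and a window of 6, a 10-token sequence keeps tokens 0, 1 and the last four.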
It might be related to the changes in #32315, but that PR raises errors on my end.
https://github.com/huggingface/transformers/pull/33692/commits/45c70a70c1588ae839a097088d8b36e4a688d859 This commit tries to solve the problem above. I am not sure whether it affects code elsewhere...
@why-in-Shanghaitech yes, SinkCache is currently very heavily bugged and the tests you mentioned are failing. We know when it started breaking: when we modified the attention mask preparation stages. Previously the re-rotations worked because the attention mask (and consequently the position ids) never went beyond the `window_length`.
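The re-rotation being discussed is the RoPE adjustment: when the cache evicts tokens, a cached key's position shifts, so the key must be rotated by the position delta to stay consistent with the queries. A sketch on a single 2-d feature pair (a simplified stand-in for the real rotary embedding, with a fixed `theta`):

```python
import math

def rope_rotate(vec2, pos, theta=1.0):
    """Apply a RoPE-style 2-d rotation at position `pos`."""
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    x, y = vec2
    return (x * c - y * s, x * s + y * c)

def rerotate(vec2, old_pos, new_pos, theta=1.0):
    """Re-rotate an already-rotated key from `old_pos` to `new_pos` by
    applying the delta angle, as SinkCache does after evicting tokens."""
    delta = new_pos - old_pos
    c, s = math.cos(delta * theta), math.sin(delta * theta)
    x, y = vec2
    return (x * c - y * s, x * s + y * c)
```

Re-rotating a key encoded at position 5 down to position 2 yields exactly the key that would have been encoded at position 2, which is why the scheme only works while positions stay bounded.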
Currently my idea was to bring back the cropping, since keeping cos/sin up to the maximum possible sequence length seemed a bit inefficient in terms of memory. But I see your approach is also a possible solution. Since I paused work on SinkCache for a few weeks to focus on other tasks, tagging @gante
@gante will say what is better for SinkCache; he is the point of contact for anything in generation.
Thank you for the reply!
Yes... I think cropping is necessary, since the attention mask grows without taking the past key values into account.
Maybe we can crop here instead? The past key values are available in `model_kwargs`. Or maybe we should process the attention mask in the `Cache` classes?
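The cropping being suggested would just trim the mask to the number of entries actually held in the cache. A minimal sketch on a plain nested-list mask (hypothetical helper, not the actual transformers code):

```python
def crop_attention_mask(attention_mask, cache_length):
    """Crop a (batch, seq) attention mask so its length never exceeds
    the number of key/value entries the sink cache actually holds."""
    if attention_mask and len(attention_mask[0]) > cache_length:
        return [row[-cache_length:] for row in attention_mask]
    return attention_mask
```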
Also, since we cannot modify the positional encoding of the queries, it seems impossible to keep the cos/sin within the window length. But I don't think we need to: this PR only caches cos/sin up to window length + the longest prompt length, not the entire generated sequence length.
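To make the memory argument concrete, here is a sketch of precomputing a cos/sin table whose size is bounded by window length plus the longest prompt length rather than by the full generated sequence (simplified single-frequency version; the function name is made up for illustration):

```python
import math

def build_cos_sin_cache(window_length, longest_prompt_length, theta=1.0):
    """Precompute cos/sin for positions 0..max_positions-1, where
    max_positions = window_length + longest_prompt_length. Under the
    sink-cache scheme, positions never need to exceed this bound."""
    max_positions = window_length + longest_prompt_length
    cos = [math.cos(p * theta) for p in range(max_positions)]
    sin = [math.sin(p * theta) for p in range(max_positions)]
    return cos, sin
```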
What does this PR do?
Fixes #33691
This PR shifts the new key states along with those already in the cache. It could also solve some issues mentioned in #31537, #32315, where the initial input key states are longer than the context window or shorter than the attn sink.
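A simplified model of the shift, covering the two edge cases mentioned (prompt longer than the window, or shorter than the sink budget): given how many keys are cached and how many arrive, compute the positions the incoming keys end up at after the sink + sliding-window policy is applied. This is a hypothetical illustration, not the actual PR code:

```python
def shifted_positions(num_cached, num_new, num_sink_tokens, window_length):
    """Positions assigned to incoming keys after the cache applies its
    sink + sliding-window policy. Newest key always lands at the end of
    the window; only the last (window_length - num_sink_tokens) new keys
    can survive if the input overflows the window."""
    total = num_cached + num_new
    if total <= window_length:
        # Nothing evicted: new keys simply extend the cache.
        return list(range(num_cached, total))
    kept_new = min(num_new, window_length - num_sink_tokens)
    return list(range(window_length - kept_new, window_length))
```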
I hope it could better explain the issue in #33691.
It does not solve the issue that the new key states have large positional encodings (see the last few lines of #33691), so it still does not work with `model.generate`. Once that is fixed, the model should generate tokens correctly. Also, it would be better if we could come up with some good tests, e.g. redesign the key state so that it carries its actual positional encoding as an attribute in the test?
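The test idea above could look like the following: a position-tagged key test double that remembers the position it was last rotated to, so a test can assert that after re-rotation every key's position stays inside the window. Names are hypothetical:

```python
class PositionTaggedKey:
    """Test double: a fake key state that records the position it was
    encoded with, updated on every re-rotation."""
    def __init__(self, pos):
        self.pos = pos

    def rerotate(self, delta):
        """Model a RoPE re-rotation by `delta` positions."""
        self.pos += delta
        return self

def all_within_window(keys, window_length):
    """Assertion helper: every key's effective position is in-window."""
    return all(0 <= k.pos < window_length for k in keys)
```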
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@ArthurZucker @zucchini-nlp @gante I haven't finished reading all the previous discussion (it's quite long)... but I know you should be very familiar with this.