tomaarsen / attention_sinks

Extend existing LLMs way beyond the original training length with constant memory usage, without retraining
https://huggingface.co/blog/tomaarsen/attention-sinks
Apache License 2.0

Add QWen model + benchmark results #15

Closed Sanster closed 9 months ago

Sanster commented 9 months ago

TODO

The code for the QWen model is remote code (loaded with trust_remote_code). I tried to add it to AutoModelForCausalLM by following the approach used for other models (e.g. llama), but it didn't work: AutoModelForCausalLM.from_pretrained did not use the code in attention_sinks/models/qwen/modeling_qwen.py. So for now, when I run perplexity.py, I temporarily modify it like this:

        if "qwen" in args.model_name_or_path.lower():
            # TODO: Make AutoModelForCausalLM.from_pretrainied work for qwen
            from attention_sinks.models import QWenLMHeadModel as AutoModelForCausalLM
        else:
            from attention_sinks import AutoModelForCausalLM

Benchmarks

Experiment: Multi-group attention sinks

(The code changes for this experiment are not included in this PR.)

I also ran experiments with multiple attention sinks. My goal is to reduce the model's memory loss beyond the KV cache size, but for now I have only run the perplexity (ppl) test, which does not show whether context memory loss has actually been reduced.

[figure: perplexity for different numbers of attention sinks]
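For reference, a sink-count sweep could be set up roughly like this (only a sketch, not the actual experiment code; the attention_sink_size / attention_sink_window_size kwargs are assumed to match the ones used for the other models in this repo):

    from attention_sinks.models import QWenLMHeadModel

    # Rough sketch of a sink-count sweep; the kwargs below are assumed, and the
    # multi-group changes themselves are not part of this PR.
    for num_sinks in (4, 12):
        model = QWenLMHeadModel.from_pretrained(
            "Qwen/Qwen-7B",
            device_map="auto",
            trust_remote_code=True,
            attention_sink_size=num_sinks,                # sink tokens kept at the start
            attention_sink_window_size=1024 - num_sinks,  # rest of the fixed KV cache
        )
        # ... run the perplexity evaluation with this model ...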

tomaarsen commented 9 months ago

Hello!

This is looking awesome! Great job. Regarding the AutoModel difficulties, this is related to all Qwen models using trust_remote_code=True, which kind of bypasses the AutoModel classes. I'd like to look into this to figure out if there's a convenient solution, but otherwise your QWenLMHeadModel solution works. I'll try to find a solution tomorrow before I merge this, if that works for you.
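One avenue to explore (just a sketch under the assumption that the QWenConfig and QWenLMHeadModel classes from this PR are importable from that path, not a confirmed fix) is transformers' registration API, which lets a model_type resolve to local classes:

    from transformers import AutoConfig, AutoModelForCausalLM
    # Hypothetical import path for the classes added in this PR
    from attention_sinks.models.qwen import QWenConfig, QWenLMHeadModel

    # Register the local QWen classes so that AutoModelForCausalLM.from_pretrained
    # resolves to them instead of the remote code on the Hub.
    AutoConfig.register("qwen", QWenConfig)
    AutoModelForCausalLM.register(QWenConfig, QWenLMHeadModel)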

As for your experiments, it's very interesting to see that 12 sinks is actually quite a lot better (in terms of perplexity) than 4; I had expected the difference to be smaller. I'm not quite sure what the second plot represents, though. Is it the difference between many sink tokens with a large window versus 12 sink tokens with a smaller window?

tomaarsen commented 9 months ago

Finding a fix for the AutoModel issue also has consequences for e.g. https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat, which would also not work with AutoModel right now.

Sanster commented 9 months ago

> I'm not quite sure what the second plot represents, though. Is it the difference between many sink tokens with a large window versus 12 sink tokens with a smaller window?
>
> • Tom Aarsen

In the second picture I originally wanted to compare the effect of using sink groups versus not using them; I have now combined everything into a single plot.

[figure: combined perplexity plot with and without sink groups]

tomaarsen commented 9 months ago

@Sanster I'm unable to reproduce the behaviour for transformers - the model seems fairly stable for me beyond the 2400 tokens where you encounter issues.

I've also started work on a large refactor that allows from attention_sinks import AutoModelForCausalLM to also work with Qwen. I'll be publishing that later today - it'll incorporate this awesome PR as well!

tomaarsen commented 9 months ago

This is ready as far as I'm concerned. I'll leave it open for now so you can also review my changes. I'm afraid they're somewhat squashed together with the merge, but the gist is this (a rough sketch of the pattern follows the list):

  1. Use the now general inject_mixin.py file to handle injecting the Attention Sinks into models.
  2. Then, the models/qwen folder only needs the position-shifting function, and no longer the full modeling and configuration files.
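Roughly, the injection pattern looks like the sketch below. This is only my illustration of the idea, not the actual inject_mixin.py code, and the helper names here are hypothetical:

    from transformers import AutoModelForCausalLM as _TransformersAutoModel

    class InjectAttentionSinksMixin:
        # Hypothetical registry mapping a model_type to its position-shifting
        # function (the only per-architecture piece that remains in models/qwen).
        _position_shift_fns = {}

        @classmethod
        def from_pretrained(cls, model_name_or_path, **kwargs):
            # Pull out the sink-specific settings before calling transformers
            attention_sink_size = kwargs.pop("attention_sink_size", 4)
            attention_sink_window_size = kwargs.pop("attention_sink_window_size", 1020)
            model = _TransformersAutoModel.from_pretrained(model_name_or_path, **kwargs)
            shift_fn = cls._position_shift_fns.get(model.config.model_type)
            if shift_fn is not None:
                # Patch each attention module so cached keys/values are kept in a
                # fixed "sinks + sliding window" cache with shifted positions.
                for module in model.modules():
                    if module.__class__.__name__.endswith("Attention"):
                        module.forward = shift_fn(
                            module, attention_sink_size, attention_sink_window_size
                        )
            return model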

This does mean that there is no longer any from attention_sinks import QWenLMHeadModel, but you can use:

    from attention_sinks import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True)
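From there, generation should work like with any other transformers model (a minimal usage sketch, assuming the standard Qwen tokenizer setup):

    from transformers import AutoTokenizer

    # Qwen's tokenizer also lives in remote code, so trust_remote_code is needed here too
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
    inputs = tokenizer("Attention sinks let the model", return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=50)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))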

I do have to say, I still couldn't reproduce the behaviour you saw with plain transformers; that model seemed to do pretty well, just like attention_sinks. It used a ton of VRAM, though.

tomaarsen commented 9 months ago

Thanks again!