tomaarsen / attention_sinks

Extend existing LLMs way beyond the original training length with constant memory usage, without retraining
https://huggingface.co/blog/tomaarsen/attention-sinks
Apache License 2.0

Strategy for trust_remote_code? #19

Closed by kmn1024 9 months ago

kmn1024 commented 9 months ago

I'm interested in using attention_sinks for models such as: https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py

I think I can reuse much of the code for gpt_neox_pos_shift_attention_forward, but I'm wondering how inject_mixin.py would need to be changed to make the plumbing work. In your opinion, what's the cleanest way to make this change?

tomaarsen commented 9 months ago

Hello!

Great question! It involves creating a directory under models for stablelm_epoch with a pos_shift.py file. This file will contain a copy of the forward method of Attention from https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py, modified in the same way as for gpt_neox. The gist is that the key rotation must be done after caching, i.e. after this snippet.
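
Roughly, the idea looks like the sketch below (plain PyTorch, not the actual attention_sinks or StableLM code; the helper names and tensor shapes are just illustrative assumptions): the keys go into the cache unrotated, and only afterwards do they get rotary embeddings based on their position inside the current cache window.

    # Illustrative sketch only, not the real pos_shift forward; names and shapes are assumptions.
    import torch

    def rotate_half(x):
        # Standard rotary helper: split the head dim in half and swap with a sign flip.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary(x, cos, sin, position_ids):
        # x: (batch, heads, seq, head_dim); cos/sin: (max_pos, head_dim)
        cos = cos[position_ids].unsqueeze(1)  # -> (batch, 1, seq, head_dim)
        sin = sin[position_ids].unsqueeze(1)
        return (x * cos) + (rotate_half(x) * sin)

    def pos_shift_cache_update(query, key, value, past_key_value, cos, sin, position_ids):
        # Rotate the queries at the positions passed in, as usual.
        query = apply_rotary(query, cos, sin, position_ids)

        # Cache the keys *before* rotating them, so the cache always holds unrotated keys.
        if past_key_value is not None:
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)
        past_key_value = (key, value)

        # Only now rotate all cached keys, using positions 0..kv_seq_len-1 within the window.
        kv_seq_len = key.shape[2]
        key_position_ids = torch.arange(kv_seq_len, device=key.device).unsqueeze(0)
        key = apply_rotary(key, cos, sin, key_position_ids)
        return query, key, value, past_key_value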

Then, https://github.com/tomaarsen/attention_sinks/blob/main/attention_sinks/inject_mixin.py must be updated in 4 ways (a consolidated sketch of all four additions follows the list):

  1. Add

    "stablelm_epoch": "StableLMEpochModel"

    to https://github.com/tomaarsen/attention_sinks/blob/fc335310cf9b9425ef9572365e3ee52ac0d2164a/attention_sinks/inject_mixin.py#L12. This allows the Inject mixin to place the Attention Sink KV cache on the right object.

  2. Add

    "stablelm_epoch": "Attention",

    to https://github.com/tomaarsen/attention_sinks/blob/fc335310cf9b9425ef9572365e3ee52ac0d2164a/attention_sinks/inject_mixin.py#L21. This allows the Inject mixin to update the forward method of the right class instances.

  3. Add

    "stablelm_epoch": (2, 2),

    to https://github.com/tomaarsen/attention_sinks/blob/fc335310cf9b9425ef9572365e3ee52ac0d2164a/attention_sinks/inject_mixin.py#L30. The 2s are the seq_len dimensions of the key and value states; as you can see in this snippet, both dimensions are 2.

  4. Add

    "stablelm_epoch": stablelm_epoch_pos_shift_attention_forward,

    to https://github.com/tomaarsen/attention_sinks/blob/fc335310cf9b9425ef9572365e3ee52ac0d2164a/attention_sinks/inject_mixin.py#L93. This is the position-shifting forward method that you would create.
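
Put together, the four additions would look roughly like the sketch below. The dictionary names and the import path are placeholders, not the actual identifiers in inject_mixin.py; check the linked lines for the real ones.

    # Placeholder names: the real mapping dicts in inject_mixin.py may be named differently,
    # and the import path for the new forward is illustrative only.
    from attention_sinks.models.stablelm_epoch.pos_shift import (  # hypothetical module path
        stablelm_epoch_pos_shift_attention_forward,
    )

    MODEL_NAME_MAPPING = {
        # ... existing architectures ...
        "stablelm_epoch": "StableLMEpochModel",  # 1. where the Attention Sink KV cache is attached
    }
    ATTENTION_NAME_MAPPING = {
        # ... existing architectures ...
        "stablelm_epoch": "Attention",  # 2. which class gets its forward method replaced
    }
    KV_DIM_MAPPING = {
        # ... existing architectures ...
        "stablelm_epoch": (2, 2),  # 3. seq_len dimensions of the key and value states
    }
    POS_SHIFT_FORWARD_MAPPING = {
        # ... existing architectures ...
        "stablelm_epoch": stablelm_epoch_pos_shift_attention_forward,  # 4. the new forward
    }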

Afterwards, you can create a benchmark_stablelm_epoch.sh here; you can just copy any of the existing files and change the model and names slightly. After running it, you should hopefully see a nice figure like the one at the top of the README.

That should be all it takes!

Afterwards, you should be able to do:

from attention_sinks import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-3b-4e1t", trust_remote_code=True)

See also #15 or its merge commit (fc335310cf9b9425ef9572365e3ee52ac0d2164a) for another model that requires trust_remote_code=True.

Hope this helps! I'll gladly welcome a PR for this architecture.