tomaarsen / attention_sinks

Extend existing LLMs way beyond the original training length with constant memory usage, without retraining
https://huggingface.co/blog/tomaarsen/attention-sinks
Apache License 2.0

Add `model.generate` support #6

Closed · tomaarsen closed 10 months ago

tomaarsen commented 10 months ago

Closes #1

Hello!

Pull Request overview

* Prevent the `attention_mask` from growing past the capped `past_key_values`, so that `model.generate` works with attention sinks.

Details

The `_update_model_kwargs_for_generation` method in `GenerationMixin` would endlessly grow the `attention_mask` to match the length of `past_key_values` + 1, which is normally very reasonable. However, with `attention_sinks` we eventually cap the `past_key_values`, so generation ended up crashing.

This change simply prevents the endless growth of the `attention_mask`, so that its length always matches the (capped) `past_key_values` length + 1.
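
For illustration, here is a minimal sketch of that idea (the helper name `cap_attention_mask` and its exact signature are assumptions for this sketch, not the actual patch):

```python
import torch

def cap_attention_mask(attention_mask: torch.Tensor, cache_length: int) -> torch.Tensor:
    """Append a 1 for the newly generated token, then keep only the last
    `cache_length + 1` positions so the mask never outgrows the capped cache."""
    batch_size = attention_mask.shape[0]
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1
    )
    return attention_mask[:, -(cache_length + 1):]
```

With the defaults used below (4 sink tokens plus a 252-token window), the cache tops out at 256 entries, so the mask settles at 257 positions no matter how long generation runs.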

Usage

```python
import torch
from transformers import AutoTokenizer, TextStreamer, GenerationConfig
from attention_sinks import AutoModelForCausalLM

# model_id = "meta-llama/Llama-2-7b-hf"
# model_id = "mistralai/Mistral-7B-v0.1"
model_id = "mosaicml/mpt-7b"
# model_id = "tiiuae/falcon-7b"
# model_id = "EleutherAI/pythia-6.9b-deduped"

# Load the chosen model and corresponding tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # for efficiency:
    device_map="auto",
    torch_dtype=torch.float16,
    # `attention_sinks`-specific arguments:
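    # (the 4 initial "sink" tokens are always kept, plus a sliding window of
    #  the 252 most recent tokens, capping the key/value cache at 256 tokens)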
    attention_sink_size=4,
    attention_sink_window_size=252,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Our input text
text = "Vaswani et al. (2017) introduced the Transformers"

# Encode the text
input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

# Print tokens as they're being generated
streamer = TextStreamer(tokenizer)
generated_tokens = model.generate(
    input_ids,
    generation_config=GenerationConfig(
        # use_cache=True is required, the rest can be changed up.
        use_cache=True,
        min_new_tokens=20000,
        max_new_tokens=50000,
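        # penalty_alpha together with top_k > 1 selects contrastive search decoding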
        penalty_alpha=0.6,
        top_k=5,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    ),
    streamer=streamer,
)
# Decode the final generated text
output_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
```

sparverius commented 10 months ago

Confirmed working, even tested with a few GPTQ models! Just needed to install from git:

```bash
pip install git+https://github.com/tomaarsen/attention_sinks.git
```

tomaarsen commented 10 months ago

That's awesome! I'm preparing a release now so the install is a bit easier - I'm just doing some edits on the README and CHANGELOG first :)

Thanks for helping with testing!

tomaarsen commented 10 months ago

v0.2.2 has been released, which includes this PR.