tomaarsen / attention_sinks

Extend existing LLMs way beyond the original training length with constant memory usage, without retraining
https://huggingface.co/blog/tomaarsen/attention-sinks
Apache License 2.0

Trying a minimal example with LlamaForCausalLM, sadly it fails #1

Closed alexbalandi closed 9 months ago

alexbalandi commented 9 months ago

My minimal example:

import torch
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
from attention_sinks import LlamaForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

repo = "meta-llama/Llama-2-13b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto", load_in_4bit=True)
# Set the text you want to generate text based on

#text = "<s> you are hepful assistant. </s> <u> Tell me the pros and cons of coffee. Two points. </u>"
text = "<s> you are hepful assistant. </s> <u> Write me a long essay on the reasons for fall of roman empire/u>"

# Encode the text
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

# Generate text
generated_tokens = model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=4096)
# Decode the generated text
generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

print(generated_text)

Fails here:

File ~/mambaforge/envs/data_science/lib/python3.10/site-packages/attention_sinks/models/llama/pos_shift.py:103, in llama_pos_shift_attention_forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
    101 if attention_mask is not None:
    102     if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
--> 103         raise ValueError(
    104             f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
    105         )
    106     attn_weights = attn_weights + attention_mask
    108 # upcast attention to fp32

ValueError: Attention mask should be of size (1, 1, 1, 1025), but is torch.Size([1, 1, 1, 1026])

The root of the issue is clear, but trying dumb fixes (like slicing the attention mask to make it "fit") doesn't work. Is it at least reproducible in your env? :eyes: I'd really appreciate any pointers on ways to fix this :pray:
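
For what it's worth, my rough mental model of where the off-by-one comes from (the numbers below are assumptions on my part, not the library's actual code): generate builds an attention mask that covers every token seen so far, while the attention-sink cache caps how many keys/values are actually kept.

# Rough sketch with assumed defaults, not the library's actual logic
attention_sink_size = 4        # assumed sink size
window_size = 1020             # assumed window, so the cache holds at most 1024 entries
tokens_seen = 1026             # tokens processed so far, including the current one

mask_len = tokens_seen                                                     # what generate passes in -> 1026
kv_seq_len = min(tokens_seen - 1, attention_sink_size + window_size) + 1   # what the cache holds -> 1025

print(mask_len, kv_seq_len)    # 1026 vs 1025: the exact shapes from the ValueError above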

alexbalandi commented 9 months ago

Note: removing penalty_alpha=0.6, top_k=5 (which activate contrastive search) has no effect on the problem; it reproduces all the same.

tomaarsen commented 9 months ago

Let me look into this! I haven't tried to generate myself: I've only tried to directly call forward on the LlamaModel/FalconModel in my benchmarks.

tomaarsen commented 9 months ago

I also get failures when calling generate, although the model does work if I do the generation manually like so:

with torch.no_grad():
    past_key_values = None
    input_ids = input_ids.to(model.device)
    for i in range(4096):
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        logits = outputs.logits.view(-1, model.config.vocab_size)
        past_key_values = outputs.past_key_values
        # Greedily pick the next token and feed only that token back in,
        # relying on the (sink-truncated) KV cache for the history
        token = logits[-1, :].argmax()
        print(i, tokenizer.decode(token, clean_up_tokenization_spaces=False))
        input_ids = token.unsqueeze(0).unsqueeze(0)

It ends up writing <u> Write me a long essay on the reasons for fall of roman empire/u> over and over for thousands of tokens (because this is not an instruct-tuned model; the pure transformers model reacts the same way).

I'll also check what happens if I use the windowed attention approach, i.e. the green line here. Edit: See these outputs. The left column is the index and the right is the output token. It completely loses the plot.

996 Write
997 O
998 O
999 u
1000 /
1001 u
1002 /
1003 u
1004 /
1005 u
1006 /
1007 /
1008 /
1009 /
1010 u
1011 /
1012 /
1013 /
1014 /
1015 /
...
1704 O
1705 O
1706 O
1707 O
1708 in
1709 .
1710 Ћ
1711 nobody
1712 nobody
1713 nobody
1714 nobody
1715 nobody
1716 nobody

So, attention_sinks does work, but not with model.generate at the moment. I'll have to debug the generate method to figure out where the issue originates.
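
In case it's useful while I debug, here's roughly how I'd instrument things (my own sketch, not part of the library): decode a few tokens manually and log the KV cache length per step, so it can be compared against the mask length that generate constructs.

import torch

def log_cache_growth(model, tokenizer, text, steps=8):
    # Decode a handful of tokens manually and print how long the KV cache actually is
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
    past_key_values = None
    with torch.no_grad():
        for i in range(steps):
            outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values
            # First layer's keys have shape (batch, heads, kv_len, head_dim)
            kv_len = past_key_values[0][0].shape[-2]
            print(f"step {i}: kv cache length = {kv_len}")
            input_ids = outputs.logits[0, -1].argmax().view(1, 1)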

alexbalandi commented 9 months ago

Thank you for noticing the model: I meant to use the chat model but used the original one. With repo = "meta-llama/Llama-2-7b-chat-hf" and everything tied to it, plus your code to generate, the output certainly looks decent, and memory is stable as advertised!

But yes, I'll also look into how to either fix generate or just reimplement contrastive search myself.
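
(For anyone else who wants to eyeball the memory claim: a quick-and-dirty check, assuming the manual decoding loop from above with model and input_ids already set up, is to print the allocated CUDA memory every so often; with attention sinks it should plateau instead of growing with the sequence length.)

import torch

with torch.no_grad():
    past_key_values = None
    for i in range(4096):
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        input_ids = outputs.logits[0, -1].argmax().view(1, 1)
        if torch.cuda.is_available() and i % 256 == 0:
            # Should level off once the cache reaches sink_size + window_size
            print(f"step {i}: {torch.cuda.memory_allocated() / 2**20:.0f} MiB allocated")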

tomaarsen commented 9 months ago

Yeah, it makes sense to use the chat model; my internet is just kind of slow, so it would've taken another 25 minutes just to download it, haha. Any help is certainly welcome.

fblgit commented 9 months ago

In the script benchmark/perplexity.py, I added the following at L70:

                token = logits[-1,:].argmax()
                print(idx, tokenizer.decode(token, clean_up_tokenization_spaces=False))

The result shows:

8144 empt
nll:  0.01, ppl:     1.01:
8145 ory
nll:  0.00, ppl:     1.00:
8146 order
nll:  1.45, ppl:     4.28:
8147 of
nll:  1.25, ppl:     3.48:
8148 the
nll:  0.41, ppl:     1.51:
8149 sup
nll:  0.85, ppl:     2.34:
8150 reme
nll:  0.00, ppl:     1.00:
8151 court
nll:  0.01, ppl:     1.01:
8152 of
nll:  1.97, ppl:     7.17:
8153 ing
nll:  0.00, ppl:     1.00:
8154 the
nll:  4.88, ppl:   131.69:
8155 the
nll:  3.53, ppl:    33.97:

It doesn't generate, as mentioned before. Keep in mind that when you manipulate the attention mechanism, perplexity and loss may "apparently" improve as a result of the over-confidence caused by the attention.

IMHO it is crucial to run a real evaluation on this, such as MMLU, and compare the results.
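
To make that concrete: MMLU-style evals typically score each answer option by its log-likelihood under the model and pick the best one. A rough sketch (the question and helper below are made up, and model/tokenizer are assumed from the earlier snippets):

import torch

def option_log_likelihood(model, tokenizer, prompt, option):
    # Sum of log-probabilities the model assigns to the option's tokens, given the prompt
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    full_ids = tokenizer(prompt + option, return_tensors="pt").input_ids.to(model.device)
    with torch.no_grad():
        logits = model(full_ids).logits
    log_probs = torch.log_softmax(logits[0, :-1], dim=-1)
    option_len = full_ids.shape[1] - prompt_ids.shape[1]
    targets = full_ids[0, -option_len:]
    return log_probs[-option_len:].gather(1, targets.unsqueeze(1)).sum().item()

question = (
    "Question: What primarily causes ocean tides?\n"
    "A. Wind\nB. The Moon's gravity\nC. Ocean currents\nD. Salinity\n"
    "Answer:"
)
best = max(" A", " B", " C", " D", key=lambda opt: option_log_likelihood(model, tokenizer, question, opt))
print(best)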

tomaarsen commented 9 months ago

I agree completely. We need more thorough experiments on real benchmarks to get a feel for whether the performance holds up at longer input lengths.

tomaarsen commented 9 months ago

Feel free to experiment with #6 to get model.generate working.

tomaarsen commented 9 months ago

Please try the following snippet with the model of your choice and a corresponding prompt. The generation config here is set up to generate essentially endlessly, so the model may still eventually lose track of what it was doing, but it shouldn't forget English the way it would with pure transformers or windowed attention.

import torch
from transformers import AutoTokenizer, TextStreamer, GenerationConfig
from attention_sinks import AutoModelForCausalLM

# model_id = "meta-llama/Llama-2-7b-hf"
# model_id = "mistralai/Mistral-7B-v0.1"
model_id = "mosaicml/mpt-7b"
# model_id = "tiiuae/falcon-7b"
# model_id = "EleutherAI/pythia-6.9b-deduped"

# Load the chosen model and corresponding tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # for efficiency:
    device_map="auto",
    torch_dtype=torch.float16,
    # `attention_sinks`-specific arguments:
    attention_sink_size=4,
    attention_sink_window_size=252,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Our input text
text = "Vaswani et al. (2017) introduced the Transformers"

# Encode the text
input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

# Print tokens as they're being generated
streamer = TextStreamer(tokenizer)
generated_tokens = model.generate(
    input_ids,
    generation_config=GenerationConfig(
        # use_cache=True is required, the rest can be changed up.
        use_cache=True,
        min_new_tokens=20000,
        max_new_tokens=50000,
        penalty_alpha=0.6,
        top_k=5,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    ),
    streamer=streamer,
)
# Decode the final generated text
output_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

fblgit commented 9 months ago

I tested it; it can produce long outputs, but tons of nonsense output at some point. I would say that the </s> token has the effect of changing the context.

In any case, there is one thing that still needs a change inside the transformers library. Check the use_cache implementation inside the attention mechanism; I think the cache doesn't prevent some computation from happening, even though the result is discarded later?

To @Guangxuan-Xiao: the paper notes the importance of the first tokens, and I have observed this across attention plots. I think this is because the first token is always present, so the attention learns to use it as a fixed axis somehow. You can observe this even more pronounced in convolutional attentions like GPT-2. IMHO, this could be exploited somehow by using those tokens to hold a larger truth and using the attention heads to pass it forward... a sink_head?

tomaarsen commented 9 months ago

I don't have much time to address the rest of your comment, but I wanted to point you to this streaming demo that I just made: https://github.com/tomaarsen/attention_sinks/blob/main/demo/streaming.py

This is the primary use case of attention_sinks, I think. It does repeated prompts, as if someone is prompting a chat assistant hundreds of times sequentially. I haven't been able to test this much, but the idea from the paper is that with attention sinks, the model remains capable of responding to the prompts even after having read hundreds of thousands of messages prior.

It's a bit of a better example than the completely endless nonsense generation.
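
Roughly, the shape of that demo is something like the following (a simplified sketch, not the actual demo script; the chat template and prompts are made up, and the cache-size bookkeeping is left to attention_sinks):

import torch
from transformers import AutoTokenizer
from attention_sinks import AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompts = ["Tell me the pros and cons of coffee.", "Now summarize that in one sentence."]
past_key_values = None

with torch.no_grad():
    for prompt in prompts:
        # Feed each new prompt on top of the existing (bounded) KV cache
        input_ids = tokenizer(f"[INST] {prompt} [/INST]", return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
        for _ in range(256):
            outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values
            token = outputs.logits[0, -1].argmax()
            if token == tokenizer.eos_token_id:
                break
            print(tokenizer.decode(token), end="", flush=True)
            input_ids = token.view(1, 1)
        print()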

alexbalandi commented 9 months ago

Just chiming in here to say thank you for all your hard work that makes it easier to experiment with the results of the paper, you rock :hugs:

tomaarsen commented 9 months ago

Gladly! 😄

EGjoni commented 9 months ago

#16 causes this to happen again for Llama 2 7B models.

tomaarsen commented 9 months ago

Thanks for reporting! That was an oversight, I'll resolve it soon. Until then, you can use an earlier commit or the latest release on PyPI.

tomaarsen commented 9 months ago

Resolved in 48bb293d4fb15d08bdeb3a0425cee0ea78f8ba52, thanks again for reporting