Closed alexbalandi closed 9 months ago
Note: removing penalty_alpha=0.6, top_k=5 (which activate contrastive search) has no effect on the problem; it reproduces all the same.
Let me look into this! I haven't tried to generate myself: I've only tried to directly call forward on the LlamaModel/FalconModel in my benchmarks.
I also get failures when calling generate, although the model does work if I do the generation manually like so:
with torch.no_grad():
    past_key_values = None
    for i in range(4096):
        # Tensor.to is not in-place, so the result has to be assigned back
        input_ids = input_ids.to(model.device)
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        logits = outputs.logits.view(-1, model.config.vocab_size)
        past_key_values = outputs.past_key_values
        # Greedy decoding: take the most likely next token
        token = logits[-1, :].argmax()
        print(i, tokenizer.decode(token, clean_up_tokenization_spaces=False))
        # Feed only the new token; the earlier context lives in the KV cache
        input_ids = token.unsqueeze(0).unsqueeze(0)
It ends up writing <u> Write me a long essay on the reasons for fall of roman empire/u>
over and over for thousands of tokens (this is not an instruct-tuned model, and the pure transformers model reacts the same way).
I'll also check what happens if I use the windowed attention approach, i.e. the green line here. Edit: See these outputs. The left is the index and the right is the output token. It completely loses the plot.
996 Write
997 O
998 O
999 u
1000 /
1001 u
1002 /
1003 u
1004 /
1005 u
1006 /
1007 /
1008 /
1009 /
1010 u
1011 /
1012 /
1013 /
1014 /
1015 /
...
1704 O
1705 O
1706 O
1707 O
1708 in
1709 .
1710 Ћ
1711 nobody
1712 nobody
1713 nobody
1714 nobody
1715 nobody
1716 nobody
So, attention_sinks does work, but not with model.generate at the moment. I'll have to debug the generate method to figure out where the issue originates.
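For readers following along, the difference between the two cache policies compared above (windowed attention versus attention sinks) can be sketched without loading a model. This is only a toy illustration, not the attention_sinks implementation: `kept_positions` is a hypothetical helper, and the sizes (4 sink tokens, 252 recent tokens, 1000 tokens seen) are assumptions chosen for the example.

```python
def kept_positions(seq_len: int, window_size: int, sink_size: int = 0) -> list[int]:
    """Return the token positions that stay in the KV cache.

    sink_size=0 models plain windowed attention (only the most recent tokens).
    sink_size>0 additionally keeps the first few tokens as attention sinks.
    """
    if seq_len <= sink_size + window_size:
        return list(range(seq_len))  # cache not full yet, nothing is evicted
    sinks = list(range(sink_size))
    recent = list(range(seq_len - window_size, seq_len))
    return sinks + recent


# Same total cache size (256 entries) under both policies.
windowed = kept_positions(seq_len=1000, window_size=256)
with_sinks = kept_positions(seq_len=1000, window_size=252, sink_size=4)

print(windowed[:4])    # earliest positions still cached under windowed attention
print(with_sinks[:8])  # sink positions, then the start of the recent window
```

With plain windowed attention the first tokens are evicted once the window fills; with sinks they are always retained alongside the most recent tokens.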
Thank you for noticing the model. I meant to use the chat model, but used the original one. With repo = "meta-llama/Llama-2-7b-chat-hf" and everything tied to it, plus your code to generate, the output certainly looks decent, and memory is stable as advertised!
But yeah, I'll also look into how to either fix "generate" or just reimplement the contrastive search.
Yeah it makes sense to use the chat model, my internet is just kind of slow so it would've taken another 25 minutes just to get it downloaded haha. Any help is certainly welcome.
In the script benchmark/perplexity.py I added on L70:
token = logits[-1,:].argmax()
print(idx, tokenizer.decode(token, clean_up_tokenization_spaces=False))
The result shows:
8144 empt
nll: 0.01, ppl: 1.01:
8145 ory
nll: 0.00, ppl: 1.00:
8146 order
nll: 1.45, ppl: 4.28:
8147 of
nll: 1.25, ppl: 3.48:
8148 the
nll: 0.41, ppl: 1.51:
8149 sup
nll: 0.85, ppl: 2.34:
8150 reme
nll: 0.00, ppl: 1.00:
8151 court
nll: 0.01, ppl: 1.01:
8152 of
nll: 1.97, ppl: 7.17:
8153 ing
nll: 0.00, ppl: 1.00:
8154 the
nll: 4.88, ppl: 131.69:
8155 the
nll: 3.53, ppl: 33.97:
It doesn't generate, as mentioned before. Keep in mind that when you manipulate the attention mechanism, perplexity and loss may "apparently" improve as a result of the over-confidence caused by the attention.
IMHO it is crucial to run a real evaluation on this, such as MMLU, and compare the results.
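As a side note on reading the log above, the nll and ppl columns are related by ppl = exp(nll). The sketch below assumes the script reports nll in nats per token; `perplexity` is a hypothetical helper, and the values are copied from the log already rounded to two decimals, so recomputed numbers can differ slightly in the last digit.

```python
import math


def perplexity(nlls: list[float]) -> float:
    """Perplexity is exp of the mean negative log-likelihood (in nats)."""
    return math.exp(sum(nlls) / len(nlls))


# Per-token nll values copied from the log above (rounded to 2 decimals).
nlls = [0.01, 0.00, 1.45, 1.25, 0.41, 0.85, 0.00, 0.01, 1.97, 0.00, 4.88, 3.53]

per_token = [perplexity([nll]) for nll in nlls]  # one ppl value per token
overall = perplexity(nlls)                       # ppl over the whole excerpt

print([round(p, 2) for p in per_token])
print(round(overall, 2))
```

A handful of near-zero nll tokens pulls the average down sharply, which is one reason low perplexity alone does not show that the generated text is coherent.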
I agree completely. We need more thorough experiments on real benchmarks to get a feel for whether the performance holds up at longer input lengths.
Feel free to experiment with #6 to get model.generate working.
Please try the following snippet with the model of your choice and a corresponding prompt. The tokenizer here is set up to endlessly generate, so it may still eventually lose track of what it was doing, but it shouldn't forget English like what would happen with pure transformers or windowed attention.
import torch
from transformers import AutoTokenizer, TextStreamer, GenerationConfig
from attention_sinks import AutoModelForCausalLM

# model_id = "meta-llama/Llama-2-7b-hf"
# model_id = "mistralai/Mistral-7B-v0.1"
model_id = "mosaicml/mpt-7b"
# model_id = "tiiuae/falcon-7b"
# model_id = "EleutherAI/pythia-6.9b-deduped"

# Load the chosen model and corresponding tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # for efficiency:
    device_map="auto",
    torch_dtype=torch.float16,
    # `attention_sinks`-specific arguments:
    attention_sink_size=4,
    attention_sink_window_size=252,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Our input text
text = "Vaswani et al. (2017) introduced the Transformers"

# Encode the text
input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

# Print tokens as they're being generated
streamer = TextStreamer(tokenizer)
generated_tokens = model.generate(
    input_ids,
    generation_config=GenerationConfig(
        # use_cache=True is required, the rest can be changed up.
        use_cache=True,
        min_new_tokens=20000,
        max_new_tokens=50000,
        penalty_alpha=0.6,
        top_k=5,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    ),
    streamer=streamer,
)

# Decode the final generated text
output_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
I tested it; it can produce long outputs, though with tons of nonsense at some point. I would say that the </s> token has the effect of changing the context.
In any case, there is one thing that still needs some change inside the transformers library.
Check the use_cache implementation inside the attention mechanism: I think the cache doesn't prevent some computation from happening, even though the result is discarded later?
To @Guangxuan-Xiao: the paper notes the importance of the first tokens, and I have observed this across plots of the attention. I think this is because the first token is always present, so the attention learns to use it as a fixed axis somehow. You can observe this more strongly in convolutional attentions like GPT2. IMHO, this could be exploited by using those tokens to hold a larger truth and using the attention heads to pass it forward: a sink_head?
I don't have much time to address the rest of your comment, but I wanted to point you to this streaming demo that I just made: https://github.com/tomaarsen/attention_sinks/blob/main/demo/streaming.py
This is the primary use case of attention_sinks, I think. It runs repeated prompts, as if someone were prompting a chat assistant hundreds of times in a row. I haven't been able to test this much, but the idea from the paper is that with attention sinks, the model remains able to respond to the prompts even after having read hundreds of thousands of messages prior.
It's a bit of a better example than the completely endless nonsense generation.
Just chiming in here to say thank you for all your hard work that makes it easier to experiment with the results of the paper, you rock :hugs:
Gladly! 😄
Thanks for reporting! That was an oversight, I'll resolve it soon. Until then, you can use an earlier commit or the latest release on PyPI.
Resolved in 48bb293d4fb15d08bdeb3a0425cee0ea78f8ba52, thanks again for reporting
My minimal example:
Fails here:
The root of the issue is clear, but trying dumb fixes (like slicing the attention mask to make it "fit") doesn't work. Is it at least reproducible in your env? :eyes: I'd really appreciate any pointers on ways to fix this :pray: