huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Sink Cache Attention Scores are strange. CausalMask seems not to be working. #30926

Closed Tomorrowdawn closed 8 hours ago

Tomorrowdawn commented 1 month ago

System Info

Who can help?

No response

Information

Tasks

Reproduction

dataset:

from datasets import load_dataset
dataset = load_dataset("./xsum", split='test[:1000]')  # loaded locally due to an unresolved UTF-8 error; the data is exactly XSUM

I concatenate all 'document' texts to generate a streaming task (to test StreamingLLM). The code is trivial but long, so it is omitted; a sketch of the stream object is given below, after the run code.

model: LlamaForCausalLM, weights: llama2-7b-hf

core run code:

import torch
from transformers import SinkCache

# `model`, `device`, and the `perplexity` helper are defined elsewhere in the script.
def test_run(run_name, sink_len, win_len, stream,
             break_len=8000, get_attn_steps=[10]):
    kvcache = SinkCache(win_len, sink_len)
    perp_vs_len = []
    attn_scores_all = {}
    with torch.inference_mode():
        for i, input_ids in enumerate(stream):
            get_attn = (i in get_attn_steps)
            if i * stream.input_len > break_len:
                break
            output = model(
                input_ids.to(device),
                past_key_values=kvcache,
                use_cache=True,
                output_attentions=get_attn,
                return_dict=True
            )#type: ignore
            perp = perplexity(input_ids, output.logits).unsqueeze(0).item()
            perp_vs_len.append(((i+1)*input_ids.shape[1], perp))
            if get_attn:
                attn_scores = output.attentions
                attn_scores = {
                    "layer_0_head_0":attn_scores[0][0,0,:,:],
                    "layer_2_head_0":attn_scores[2][0,0,:,:],
                    "layer_5_head_0":attn_scores[5][0,0,:,:],
                    "layer_10_head_0":attn_scores[10][0,0,:,:],
                    "layer_-1_head_0":attn_scores[-1][0,0,:,:]
                }
                attn_scores_all[i] = attn_scores
    result = {
        'sink_len': sink_len,
        'win_len': win_len,
        'perp_vs_len': perp_vs_len,
        'attn_scores': attn_scores_all
    }
    torch.save(result, f"./result/{run_name}_sink{sink_len}perp_attn.pt")
    return result
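
The perplexity helper is not shown above. A minimal sketch of what such a helper might compute (shifted next-token cross-entropy over the chunk; this is an assumption, not the author's actual code):

import torch
import torch.nn.functional as F

def perplexity(input_ids, logits):
    # Hypothetical helper: logits[:, t] predicts input_ids[:, t + 1], so shift by one
    # position and exponentiate the mean cross-entropy over the chunk.
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:].to(shift_logits.device)
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
    return torch.exp(loss)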

The stream object yields 100 tokens per iteration; it behaves like a list of input_ids chunks, each containing many tokens.
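
A minimal sketch of how such a stream could be built from the concatenated XSUM documents (the ChunkStream class, its input_len attribute, and the tokenizer path are assumptions matching the usage above, not the omitted code):

from transformers import AutoTokenizer

class ChunkStream:
    # Hypothetical stand-in for the omitted streaming code: tokenize the concatenated
    # documents once, then yield fixed-size chunks of input_ids.
    def __init__(self, texts, tokenizer, input_len=100):
        self.input_len = input_len
        ids = tokenizer("\n\n".join(texts), return_tensors="pt").input_ids  # [1, total_len]
        self.chunks = list(ids.split(input_len, dim=-1))

    def __iter__(self):
        return iter(self.chunks)

tokenizer = AutoTokenizer.from_pretrained("llama2-7b-hf")  # local path, assumed
stream = ChunkStream(dataset["document"], tokenizer, input_len=100)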

I plotted the attention scores. However, their upper-triangular parts are not zero.
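
To make the check concrete: with a KV cache the returned score slice is [query_len, key_len], where the first past_len columns are cached keys that every query may attend to, so the entries that should be zero are those with j - i > past_len. A sketch of that check (check_causal and past_len are illustrative names, not part of the reproduction code):

import torch

def check_causal(attn, past_len):
    # attn: [query_len, key_len] slice for one layer/head, key_len = past_len + query_len.
    # Entry (i, j) should be zero whenever key j lies in the future of query i,
    # i.e. whenever j - i > past_len: the strict upper triangle shifted by past_len.
    future = torch.triu(attn, diagonal=past_len + 1)
    return torch.allclose(future, torch.zeros_like(future))

# e.g. check_causal(attn_scores["layer_0_head_0"], past_len=key_len - query_len)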

sink num = 0 (local window):

image

sink num = 16 (for streaming LLM):

image

Expected behavior

The attention score matrix with no KV cache (prompt len = 0) is correct:

image

Tomorrowdawn commented 1 month ago

Supplementary:

It works with DynamicCache.

image

So something must be wrong with SinkCache and the related control code.
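
For comparison, the only change needed to reproduce the DynamicCache run is the cache object passed to the same test_run loop (a sketch; the rest of the code stays as above):

from transformers import DynamicCache, SinkCache

kvcache = DynamicCache()                   # causal pattern looks correct
# kvcache = SinkCache(win_len, sink_len)   # upper triangle shows non-zero scores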

amyeroberts commented 1 month ago

cc @gante @ArthurZucker

ArthurZucker commented 1 month ago

Have not worked on the sink cache so will let @gante answer here!

deadpool66 commented 1 month ago

In cache_utils.py, I noticed that keys_to_keep = self.key_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2]:] might go wrong when -self.window_length + self.num_sink_tokens + key_states.shape[-2] >= 0. Not sure if it is relevant.

tomaarsen commented 1 month ago

It's been a bit since I worked on this, but I think that -self.window_length + self.num_sink_tokens + key_states.shape[-2] >= 0 is not really possible.

In the code here: https://github.com/huggingface/transformers/blob/b72752f06830cb6cf8d21c284f68e15faa100c4d/src/transformers/cache_utils.py#L703-L706

We're in the "Shifting cache" phase, i.e. the cache already exists, and now we're adding enough tokens to make it overflow. However, if it already exists, then I think (I'm not 100% on this) we always add 1 new generated token, i.e. key_states.shape[-2] is 1. So I think a non-negative value can only happen if the num_sink_tokens >= window_length - 1, which is not normal behaviour.

However, if it's somehow possible to add a bunch of tokens in one go when the cache already exists, then I think it would be possible to mess this up. In that case keys_to_keep should really be empty (as we're skipping way ahead and keeping no tokens), but the overflow of -self.window_length + self.num_sink_tokens + key_states.shape[-2] >= 0 into the positives allows some keys to stay. The new tokens will then get appended and we'll accidentally get a cache that's too large here: https://github.com/huggingface/transformers/blob/b72752f06830cb6cf8d21c284f68e15faa100c4d/src/transformers/cache_utils.py#L724
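
A toy illustration of that sign flip (the numbers and shapes below are made up to show the effect; they are not the real cache tensors):

import torch

window_length, num_sink_tokens = 8, 2
cache_len = window_length                                     # the cache is already full
key_cache = torch.arange(cache_len).view(1, 1, cache_len, 1)  # [bsz, heads, seq, dim]

for new_tokens in (1, 9):   # 1 = a normal decoding step, 9 = many tokens added in one go
    start = -window_length + num_sink_tokens + new_tokens
    keys_to_keep = key_cache[:, :, start:]
    kept = keys_to_keep.shape[-2]
    print(f"new={new_tokens} start={start} kept={kept} cache after update={num_sink_tokens + kept + new_tokens}")

# new=1 -> start=-5, 5 recent keys kept, cache length 8: fits the window as intended.
# new=9 -> start=+3, the index flips positive and still keeps 5 keys instead of 0,
#          so after appending the new tokens the cache holds 16 entries, larger than
#          the window.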

But I think that should probably cause a pretty easy-to-spot crash as the cache is now bigger than the window size, which should not be possible.

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.