huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Problems when using SinkCache for model.generate() #31381

Open Zoeyyao27 opened 1 month ago

Zoeyyao27 commented 1 month ago

Who can help?

@zucchini-nlp @gante @tomaarse

Reproduction

I am trying to use SinkCache for a multi-turn dialog, where the cache should contain the previous turns of the dialog. Here is my code:

import torch

from transformers import SinkCache

from transformers import AutoTokenizer, LlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"
model = LlamaForCausalLM.from_pretrained(
      model_id, low_cpu_mem_usage=True, device_map='auto',
      torch_dtype=torch.bfloat16, cache_dir="cache")
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="cache")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

device = model.device

cache = SinkCache(window_length=1024, num_sink_tokens=4)
prefix_list=["hello,my name is yy","what is my name?"]
for prefix in prefix_list:
    inputs = tokenizer(prefix, return_tensors='pt').to(device)

    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=64,
                            use_cache=True,
                            past_key_values=cache,
                            pad_token_id=tokenizer.pad_token_id)

    decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

    print(decoded)

I am not sure I am doing this right, but I get the following error:

Traceback (most recent call last):
  File "/data/yaoy/long_context/repeat_sirllm/main.py", line 25, in <module>
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=64,
  File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/transformers/generation/utils.py", line 1758, in generate
    result = self._sample(
  File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/transformers/generation/utils.py", line 2390, in _sample
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
  File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/transformers/generation/utils.py", line 1326, in _get_initial_cache_position
    model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
RuntimeError: upper bound and larger bound inconsistent with step sign

Expected behavior

The cache should retain information from the previous turns of the conversation.

zucchini-nlp commented 1 month ago

@Zoeyyao27 hey! This happens because you have to pass the previously generated text into generate() along with the cache. It will not be used to recompute key-values, but we need it to infer the actual sequence length and build the correct attention_mask. I slightly modified your code, see below:

prefix_list = ["Hello, my name is yy", "What is your name?"]
dialog_history = ""  # start empty; each turn appends to the history
cache = SinkCache(window_length=1024, num_sink_tokens=4)

for prefix in prefix_list:
    dialog_history += prefix
    inputs = tokenizer(dialog_history, return_tensors='pt').to(device)
    input_length = inputs.input_ids.shape[-1]

    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=64,
                            use_cache=True,
                            past_key_values=cache,
                            pad_token_id=tokenizer.pad_token_id,
                            return_dict_in_generate=True,
                        )

    decoded = tokenizer.decode(gen_out.sequences[0, input_length:], skip_special_tokens=True)
    dialog_history += decoded
    cache = gen_out.past_key_values

    print(decoded)
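Why the first call fails can be seen from the last frame of the first traceback, torch.arange(past_length, cur_len). Assuming, as the names suggest, that past_length is the number of tokens the cache reports and cur_len is the length of the new input_ids, the range is invalid whenever the new prompt alone is shorter than what is already cached. Below is a minimal pure-Python sketch of that bookkeeping; initial_cache_position is a hypothetical stand-in, not the library function, and the token counts are illustrative only.

```python
def initial_cache_position(past_length: int, cur_len: int) -> list[int]:
    """Hypothetical, simplified stand-in for the cache_position computation
    in the traceback: torch.arange(past_length, cur_len).

    torch.arange raises a RuntimeError when end < start with a positive
    step; this sketch mirrors that with a ValueError.
    """
    if cur_len < past_length:
        raise ValueError(
            f"input has {cur_len} tokens but the cache already holds {past_length}"
        )
    # Positions of the tokens that still need key/values computed
    return list(range(past_length, cur_len))


# Illustrative numbers: the first turn left 20 tokens in the cache and the
# second turn passes only its own 6-token prompt, so the range is invalid.
try:
    initial_cache_position(past_length=20, cur_len=6)
except ValueError as err:
    print("fails:", err)

# Passing the whole dialog history (20 cached + 6 new = 26 tokens) works:
# only the 6 positions that are not cached yet are produced.
print(initial_cache_position(past_length=20, cur_len=26))  # [20, 21, 22, 23, 24, 25]
```

This is why the history has to be re-sent: the cached prefix is skipped, but its length is still needed to place the new tokens.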
gante commented 1 month ago

A note adding to @zucchini-nlp's comment above: the line cache = gen_out.past_key_values is not needed. The cache object is updated in place; the only thing you need to do manually is instantiate a new cache for a brand-new chat/prompt :)
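To illustrate the in-place behaviour, here is a toy model. ToyCache and fake_generate are hypothetical, not transformers classes: the "model" only writes the uncached tail of the prompt into the cache, and the object it returns is the very one that was passed in, so reassigning it changes nothing.

```python
class ToyCache:
    """Hypothetical stand-in for a mutable cache object."""

    def __init__(self):
        self.entries = []

    def update(self, new_tokens):
        # Mutates this object; nothing new is created.
        self.entries.extend(new_tokens)

    def get_seq_length(self):
        return len(self.entries)


def fake_generate(prompt_tokens, past_key_values):
    """Hypothetical generate(): caches only the prompt tokens that are not
    cached yet, plus one 'generated' token, and returns the same cache."""
    new = prompt_tokens[past_key_values.get_seq_length():]
    past_key_values.update(new + ["<gen>"])
    return past_key_values


cache = ToyCache()
returned = fake_generate(["a", "b"], cache)
print(returned is cache)  # True: reassigning cache = returned is a no-op

fake_generate(["a", "b", "<gen>", "c"], cache)  # second turn, full history
print(cache.get_seq_length())  # 5

# Only a brand-new conversation needs a brand-new cache object.
cache = ToyCache()
```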

Zoeyyao27 commented 4 weeks ago

Thank you for your reply!

However, when I use a chat model with tokenizer.apply_chat_template, I get the following error:

Traceback (most recent call last):
  File "/data/yaoy/long_context/repeat_sirllm/main.py", line 35, in <module>
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=64,
  File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/data/yaoy/long_context/repeat_sirllm/transformers/src/transformers/generation/utils.py", line 1896, in generate
    result = self._sample(
  File "/data/yaoy/long_context/repeat_sirllm/transformers/src/transformers/generation/utils.py", line 2633, in _sample
    outputs = self(
  File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data/yaoy/long_context/repeat_sirllm/transformers/src/transformers/models/llama/modeling_llama.py", line 1162, in forward
    outputs = self.model(
  File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/yaoy/long_context/repeat_sirllm/transformers/src/transformers/models/llama/modeling_llama.py", line 938, in forward
    causal_mask = self._update_causal_mask(
  File "/data/yaoy/long_context/repeat_sirllm/transformers/src/transformers/models/llama/modeling_llama.py", line 1060, in _update_causal_mask
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
RuntimeError: The size of tensor a (68) must match the size of tensor b (50) at non-singleton dimension 0

Here is the code:

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import SinkCache

model_id = "01-ai/Yi-1.5-6B-Chat"
model = AutoModelForCausalLM.from_pretrained(
      model_id, low_cpu_mem_usage=True, device_map='auto',
      torch_dtype=torch.bfloat16, cache_dir="cache")
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="cache", device_map='auto')
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

device = model.device

prefix_list = ["Hello, my name is yy", "What is your name?"]

dialog_history = []
cache = SinkCache(window_length=50, num_sink_tokens=4)

for prefix in prefix_list:
    dialog_history.append({"role": "user", "content": prefix})
    input_text = tokenizer.apply_chat_template(dialog_history, tokenize=False)

    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    input_length = inputs.input_ids.shape[-1]

    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=64,
                            use_cache=True,
                            past_key_values=cache,
                            pad_token_id=tokenizer.pad_token_id,
                            return_dict_in_generate=True,
                        )

    decoded = tokenizer.decode(gen_out.sequences[0, input_length:], skip_special_tokens=True)
    dialog_history.append({"role": "assistant", "content": decoded})
    print(decoded)
    print(cache)

In fact, I don't think apply_chat_template causes the problem: if I use a smaller window_length in SinkCache, I get the same error without it. Here is the code:

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import SinkCache

model_id = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
      model_id, low_cpu_mem_usage=True, device_map='auto',
      torch_dtype=torch.bfloat16, cache_dir="cache")
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="cache", device_map='auto')
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

device = model.device

prefix_list = ["Hello, my name is yy", "What is your name?"]
dialog_history = prefix_list[0]
cache = SinkCache(window_length=20, num_sink_tokens=4)

for prefix in prefix_list:
    dialog_history += prefix
    inputs = tokenizer(dialog_history, return_tensors='pt').to(device)
    input_length = inputs.input_ids.shape[-1]

    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=64,
                            use_cache=True,
                            past_key_values=cache,
                            pad_token_id=tokenizer.pad_token_id,
                            return_dict_in_generate=True,
                        )

    decoded = tokenizer.decode(gen_out.sequences[0, input_length:], skip_special_tokens=True)
    dialog_history += decoded
    cache = gen_out.past_key_values

    print(decoded)
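To see why a small window_length matters, here is a simplified, hypothetical model of the SinkCache eviction policy: keep the first num_sink_tokens positions and only the most recent tokens, at most window_length in total. It is not the library's code and omits details such as re-rotating the kept keys. What it illustrates: once the window is exceeded, the length the cache reports no longer equals the number of tokens processed, which is one plausible way input_ids could be cropped incorrectly when generation continues from the cache.

```python
class ToySinkCache:
    """Hypothetical, simplified model of SinkCache eviction. It tracks which
    original token positions are stored rather than actual key/values."""

    def __init__(self, window_length: int, num_sink_tokens: int):
        if window_length <= num_sink_tokens:
            raise ValueError("window_length must exceed num_sink_tokens")
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.positions = []   # original positions currently kept
        self.seen_tokens = 0  # tokens processed so far

    def update(self, num_new: int) -> None:
        new = range(self.seen_tokens, self.seen_tokens + num_new)
        self.seen_tokens += num_new
        self.positions.extend(new)
        if len(self.positions) > self.window_length:
            # Keep the sink tokens plus the most recent ones.
            sinks = self.positions[: self.num_sink_tokens]
            recent_budget = self.window_length - self.num_sink_tokens
            self.positions = sinks + self.positions[-recent_budget:]

    def get_seq_length(self) -> int:
        return len(self.positions)


cache = ToySinkCache(window_length=20, num_sink_tokens=4)
cache.update(12)  # first-turn prompt (illustrative length)
cache.update(64)  # up to max_new_tokens=64 generated tokens
print(cache.seen_tokens, cache.get_seq_length())  # 76 20
print(cache.positions[:6])  # [0, 1, 2, 3, 60, 61]
```

With these illustrative numbers, the cache reports 20 stored tokens after 76 have been processed, so treating the reported length as "tokens already seen" would misjudge which 56 tokens of the next prompt are new.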
zucchini-nlp commented 3 weeks ago

Hmm, right, there's a bug in how we crop input_ids when continuing generation from SinkCache. @gante, will you fix it, or should I open a PR later this week?

That said, I'm not sure how the fix will fit into the current API. We probably need to update the caching API soon and add more rigorous tests for all cache types, WDYT?

gante commented 3 weeks ago

@zucchini-nlp please have a go at it :)

NiftyliuS commented 11 hours ago

While you are at it, there is a related bug when pre-loading tokens into the cache:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import SinkCache

# model_path: the checkpoint used in the report (not shown in the original)
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

def preload_tokens(tokens, attention_mask, sink_cache=None):
    # One forward pass that writes the prompt's key/values into sink_cache
    with torch.inference_mode():
        model(
            input_ids=tokens,
            attention_mask=attention_mask,
            use_cache=True,
            past_key_values=sink_cache,
        )

cache = SinkCache(num_sink_tokens=4, window_length=model.config.max_position_embeddings)

text = "Hello there!"
encoded_dict = tokenizer.encode_plus(
    text,
    add_special_tokens=True,
    return_attention_mask=True,
    return_tensors='pt'
)

preload_tokens(
    tokens=encoded_dict["input_ids"],
    attention_mask=encoded_dict["attention_mask"],
    sink_cache=cache
)

# So far so good and everything works as expected.

text = "Hello there! General Kenobi!"
encoded_dict_extended = tokenizer.encode_plus(
    text,
    add_special_tokens=True,
    return_attention_mask=True,
    return_tensors='pt'
)

# Second call on the same cache: raises the RuntimeError below
preload_tokens(
    tokens=encoded_dict_extended["input_ids"],
    attention_mask=encoded_dict_extended["attention_mask"],
    sink_cache=cache
)

RuntimeError('The expanded size of the tensor (12) must match the existing size (8) at non-singleton dimension 3. Target sizes: [1, 32, 8, 12]. Tensor sizes: [1, 1, 8, 8]')

Interestingly, the initial text is 4 tokens long while the extended text is 8 tokens. There might be an addition bug somewhere, since the same behaviour shows up with different texts (the initial input length is added to the extended input length).

NiftyliuS commented 10 hours ago

Update: not passing the attention mask fixes the issue. I am not a fan of this fix, but hey, if it works, it works.

PS: DynamicCache has the same issue.

zucchini-nlp commented 2 hours ago

@NiftyliuS hey!

When a cache object is passed as input, the Attention module concatenates the current key-values with the past key-values (from the cache) to compute attention scores. That means the attention matrix has shape (new_ids_length, past_kv_length + new_ids_length). Therefore, if we use a cache object to call forward() iteratively (instead of calling generate()), we have to make sure the attention mask has shape (batch_size, past_kv_length + new_ids_length).
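The shape rule can be checked against the numbers in the report: 4 tokens already cached and an 8-token extended prompt. Feeding all 8 tokens again gives 4 + 8 = 12 key positions against an 8-long mask, which matches the [1, 32, 8, 12] vs [1, 1, 8, 8] error. The helper incremental_inputs below is hypothetical (not a transformers API) and uses plain lists instead of tensors; it assumes the first 4 tokens of the extended encoding are exactly the cached ones and that there is no padding.

```python
def incremental_inputs(full_ids, past_length):
    """Hypothetical helper: return only the uncached tokens plus an
    attention mask of length past_kv_length + new_ids_length."""
    if past_length > len(full_ids):
        raise ValueError("cache holds more tokens than the sequence")
    new_ids = full_ids[past_length:]
    attention_mask = [1] * (past_length + len(new_ids))
    return new_ids, attention_mask


full_ids = list(range(8))  # stand-in ids for the 8-token extended prompt
past_length = 4            # tokens already in the cache

# What the report did: all 8 tokens again, with an 8-long mask.
print(past_length + len(full_ids), len(full_ids))  # 12 8 -> shape mismatch

new_ids, mask = incremental_inputs(full_ids, past_length)
print(len(new_ids), len(mask))  # 4 8
```

With tensors, the same idea would be passing input_ids[:, past_length:] together with the extended encoding's full attention_mask, whose length (8 = 4 + 4) already equals past_kv_length + new_ids_length under the assumptions above.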

In your case, you either

Hope this clarifies your question! Btw, this is not the first time people are confused about using a cache object outside generate(), so I am considering writing better docs on that :)