huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Keep Tuple of past key values as an option #31962

Open · jiqing-feng opened this issue 1 month ago

jiqing-feng commented 1 month ago

Feature request

I see that llama will remove the tuple format for past key values in 4.43.

Motivation

Model outputs should always keep the tuple type as an option. If we remove it, all models will fail in jit trace.

Your contribution

@gante Do you mind taking a look at it? Thx!

cc @amyeroberts

jiqing-feng commented 1 month ago

The jit trace is a core optimization in optimum-intel; this change will break all of its use cases.

cc @echarlaix

jiqing-feng commented 1 month ago

For now, jit trace still works because to_legacy_cache converts DynamicCache to a tuple, so I opened this issue just to make sure we will not eliminate the return_legacy_cache parameter. Thx!

As for the warning: if return_legacy_cache converts the returned past_key_values to a tuple, then the past_key_values passed in the next round will also be a tuple, so the model inputs must keep accepting past_key_values as a tuple.
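
For reference, a minimal sketch of that round trip (this is only an illustration, assuming the existing DynamicCache helpers to_legacy_cache / from_legacy_cache; the checkpoint is just a small example):

from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import torch

tokenizer = AutoTokenizer.from_pretrained("Felladrin/Llama-68M-Chat-v1")
model = AutoModelForCausalLM.from_pretrained("Felladrin/Llama-68M-Chat-v1")

inputs = tokenizer("The theory of special relativity states ", return_tensors="pt")
with torch.no_grad():
    # Passing an explicit DynamicCache makes the returned past_key_values a Cache object
    outputs = model(**inputs, past_key_values=DynamicCache(), use_cache=True)

# Convert the Cache object to the legacy format: a tuple of (key, value) pairs, one per layer
legacy_cache = outputs.past_key_values.to_legacy_cache()

# ...and back, so a later forward call can consume it as a Cache object again
cache = DynamicCache.from_legacy_cache(legacy_cache)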

gante commented 1 month ago

Hi @jiqing-feng @echarlaix 👋

Let me start off with the following: if it is breaking for optimum-intel, we will keep it :)


That being said, it would be interesting for both parties if removing the Tuple support is achievable! For us, it would mean less code to maintain. For optimum-intel, it would mean potential support for all new Cache classes, including quantized caches.

The default Cache, DynamicCache, is not jit-friendly, as it holds tensors with dynamic shapes. Have you tried using StaticCache, our compilation-friendly cache? Something like this (this is a generate example; the API for forward is the same):

from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
prompt_length = input_ids.input_ids.shape[1]
model.generation_config.max_new_tokens = 16

past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    # If you plan to reuse the cache, make sure the cache length is large enough for all cases
    max_cache_len=prompt_length+model.generation_config.max_new_tokens,
    device=model.device,
    dtype=model.dtype
)
outputs = model.generate(**input_ids, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

In the specific case of generate, the usage can be further simplified to

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

model.generation_config.cache_implementation = "static"

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

jiqing-feng commented 1 month ago

Hi @gante. Thanks for your support. We are talking about jit.trace, not torch.compile. We still want to keep tuple past_key_values, not because of fixed shapes, but because jit.trace only accepts tuple or dict inputs. Custom classes like Cache cannot be recognized by jit. Our usage in optimum-intel is more like:

from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)

tokenizer = AutoTokenizer.from_pretrained("Felladrin/Llama-68M-Chat-v1")
model = AutoModelForCausalLM.from_pretrained("Felladrin/Llama-68M-Chat-v1")

input_text = "The theory of special relativity states "
inputs = tokenizer(input_text, return_tensors="pt")
prompt_length = inputs.input_ids.shape[1]
model.generation_config.max_new_tokens = 16

past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    # If you plan to reuse the cache, make sure the cache length is large enough for all cases
    max_cache_len=prompt_length+model.generation_config.max_new_tokens,
    device=model.device,
    dtype=model.dtype
)
inputs = inputs.data
inputs["past_key_values"] = past_key_values

trace_model = torch.jit.trace(model, example_kwarg_inputs=inputs, strict=False)

Traceback: (error screenshot attached)

gante commented 1 month ago

it's because jit.trace only accepts tuple or dict inputs.

😭 Thank you for clarifying; I wasn't aware of this limitation.

What if we add a pre- and post-forward hook in optimum-intel, such that the model always converts from/to the legacy cache format there? In other words, transformers would remove the conversion bits in forward, but optimum-intel's models would always run them.

Do you think this could work? I'm happy to work on it myself :)
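
For concreteness, a rough sketch of what such hooks could look like (just an illustration using PyTorch's with_kwargs forward hooks and a small example checkpoint, not the optimum-intel implementation; whether this plays nicely with jit.trace is exactly what we would need to verify):

import torch
from transformers import AutoModelForCausalLM, DynamicCache

model = AutoModelForCausalLM.from_pretrained("Felladrin/Llama-68M-Chat-v1")

def to_cache_pre_hook(module, args, kwargs):
    # Before forward: turn a legacy tuple into the Cache object the model expects
    pkv = kwargs.get("past_key_values")
    if isinstance(pkv, tuple):
        kwargs["past_key_values"] = DynamicCache.from_legacy_cache(pkv)
    return args, kwargs

def to_tuple_post_hook(module, args, kwargs, output):
    # After forward: turn the returned Cache object back into the legacy tuple format
    pkv = getattr(output, "past_key_values", None)
    if pkv is not None and not isinstance(pkv, tuple):
        output.past_key_values = pkv.to_legacy_cache()
    return output

model.register_forward_pre_hook(to_cache_pre_hook, with_kwargs=True)
model.register_forward_hook(to_tuple_post_hook, with_kwargs=True)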

jiqing-feng commented 1 month ago

I am not sure about that, because we have our own past key values, which have a different shape compared to the standard cache. Do you plan to remove tuples for all models? I see that only some models like llama have this plan, while other models like gpt2 and opt still keep the tuple-format past key values.

gante commented 1 month ago

The long-term plan is to shift all models to the new Cache classes. We're focusing on new models first 🤗

DynamicCache and the old tuple format should have the exact same shapes!
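
As a quick illustration of that equivalence (a minimal sketch using a small example checkpoint; the asserts just compare shapes layer by layer):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("Felladrin/Llama-68M-Chat-v1")
model = AutoModelForCausalLM.from_pretrained("Felladrin/Llama-68M-Chat-v1")

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    cache = model(**inputs, past_key_values=DynamicCache(), use_cache=True).past_key_values

for layer_idx, (key, value) in enumerate(cache.to_legacy_cache()):
    # Each legacy entry is a (key, value) pair of shape
    # [batch_size, num_key_value_heads, seq_len, head_dim], matching the tensors
    # the DynamicCache holds for the same layer
    assert key.shape == cache[layer_idx][0].shape
    assert value.shape == cache[layer_idx][1].shape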

ArthurZucker commented 1 month ago

Also cc @echarlaix here

ArthurZucker commented 1 month ago

We talked offline and optimum will indeed need an update.

brucewlee commented 4 weeks ago

Can we always keep past_key_values as an option?

If the plan is to shift towards removing past_key_values, I was wondering if there were any examples that I could see?

gante commented 4 weeks ago

@brucewlee the cache class that is now used by default on modern model architectures supports seamless conversion to/from the legacy cache types 🤗 All a custom script needs to do is convert from/to the legacy format before/after the model forward pass.

See the class reference in the transformers documentation.