Open jiqing-feng opened 1 month ago
JIT trace is a core optimization in optimum-intel, so this change will break all use cases in optimum-intel.
cc @echarlaix
For now, jit trace still works because `to_legacy_cache` converts `DynamicCache` to a tuple, so I opened this issue just to make sure the `return_legacy_cache` parameter will not be removed. Thanks!
As for the warning: if we use `return_legacy_cache` to convert `past_key_values` to a tuple, then the next round's `past_key_values` must also be a tuple. So the inputs must accept `past_key_values` as a tuple.
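To make the round trip concrete, here is a minimal sketch of the conversion described above. `ToyDynamicCache` is a hypothetical stand-in written for illustration, not the transformers class; it only borrows the method names `from_legacy_cache` and `to_legacy_cache`, and plain lists stand in for key/value tensors. The legacy format is assumed to be a tuple holding one `(key, value)` pair per layer.

```python
from typing import List, Optional, Tuple

# Plain lists stand in for key/value tensors in this sketch.
KV = List[float]
LegacyCache = Tuple[Tuple[KV, KV], ...]


class ToyDynamicCache:
    """Hypothetical stand-in for a cache object; not the transformers class."""

    def __init__(self) -> None:
        self.key_cache: List[KV] = []
        self.value_cache: List[KV] = []

    @classmethod
    def from_legacy_cache(cls, legacy: Optional[LegacyCache] = None) -> "ToyDynamicCache":
        # Copy each layer's (key, value) pair out of the tuple format.
        cache = cls()
        for key, value in legacy or ():
            cache.key_cache.append(list(key))
            cache.value_cache.append(list(value))
        return cache

    def to_legacy_cache(self) -> LegacyCache:
        # Rebuild the tuple-of-pairs format, one pair per layer.
        return tuple((list(k), list(v)) for k, v in zip(self.key_cache, self.value_cache))


# Two layers, each with a key list and a value list.
legacy_in = (([0.1, 0.2], [0.3, 0.4]), ([0.5, 0.6], [0.7, 0.8]))
cache = ToyDynamicCache.from_legacy_cache(legacy_in)
legacy_out = cache.to_legacy_cache()
```

If the model returns a tuple in one round, the same tuple has to be accepted as input in the next round, which is why the conversion needs to work in both directions.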
Hi @jiqing-feng @echarlaix 👋
Let me start off with the following: if it is breaking for `optimum-intel`, we will keep it :)
That being said, it would be interesting for both parties if removing the `Tuple` support is achievable! For us, it would mean less code to maintain. For `optimum-intel`, it would mean potential support for all new `Cache` classes, including quantized caches.
The default `Cache`, `DynamicCache`, is not jit-friendly, as it holds tensors with dynamic shapes. Have you tried using `StaticCache`, our compilation-friendly cache? Something like this (this is a `generate` example; the API for `forward` is the same):
from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
prompt_length = input_ids.input_ids.shape[1]
model.generation_config.max_new_tokens = 16
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    # If you plan to reuse the cache, make sure the cache length is large enough for all cases
    max_cache_len=prompt_length + model.generation_config.max_new_tokens,
    device=model.device,
    dtype=model.dtype,
)
outputs = model.generate(**input_ids, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
In the specific case of `generate`, the usage can be further simplified to:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
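The difference between the two cache styles can be sketched without any model. The toy classes below are hypothetical and use Python lists in place of tensors: the growing one gains a slot on every decoding step, so its length changes from call to call, while the preallocated one is sized once with `max_cache_len` (the parameter name used in the first snippet) and only overwrites slots in place. This is only an illustration of the shape behaviour, not how `DynamicCache` or `StaticCache` are implemented.

```python
from typing import List


class ToyGrowingCache:
    """Hypothetical cache whose storage grows by one slot per step."""

    def __init__(self) -> None:
        self.keys: List[float] = []

    def update(self, key: float) -> int:
        self.keys.append(key)
        return len(self.keys)  # storage length after this step


class ToyPreallocatedCache:
    """Hypothetical cache with fixed-size storage allocated up front."""

    def __init__(self, max_cache_len: int) -> None:
        self.keys: List[float] = [0.0] * max_cache_len
        self.seen = 0  # number of slots filled so far

    def update(self, key: float) -> int:
        if self.seen >= len(self.keys):
            raise ValueError("cache is full; allocate a larger max_cache_len")
        self.keys[self.seen] = key  # write in place, storage size unchanged
        self.seen += 1
        return len(self.keys)  # storage length after this step


growing = ToyGrowingCache()
static = ToyPreallocatedCache(max_cache_len=4)
growing_lengths = [growing.update(float(step)) for step in range(3)]
static_lengths = [static.update(float(step)) for step in range(3)]
```

The growing cache reports a different length at each step, whereas the preallocated one always reports `max_cache_len`; a constant shape is what compilation-oriented tooling generally relies on.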
Hi @gante, thanks for your support. We are talking about jit.trace, not torch.compile. We still want to keep tuple `past_key_values`, not because it's a fixed shape; it's because jit.trace only accepts tuple or dict inputs. Custom classes like `Cache` are not recognized by jit.
Our usage in optimum-intel is more like:
from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
tokenizer = AutoTokenizer.from_pretrained("Felladrin/Llama-68M-Chat-v1")
model = AutoModelForCausalLM.from_pretrained("Felladrin/Llama-68M-Chat-v1")
input_text = "The theory of special relativity states "
inputs = tokenizer(input_text, return_tensors="pt")
prompt_length = inputs.input_ids.shape[1]
model.generation_config.max_new_tokens = 16
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    # If you plan to reuse the cache, make sure the cache length is large enough for all cases
    max_cache_len=prompt_length + model.generation_config.max_new_tokens,
    device=model.device,
    dtype=model.dtype,
)
inputs = inputs.data
inputs["past_key_values"] = past_key_values
trace_model = torch.jit.trace(model, example_kwarg_inputs=inputs, strict=False)
traceback:
> it's because jit.trace only accepts tuple or dict inputs.

😠 Thank you for clarifying; I wasn't aware of this limitation.
What if we add a pre- and post-forward hook in `optimum-intel`, such that the model always converts from/to the legacy cache format there? In other words, `transformers` would remove the conversion bits in `forward`, but `optimum-intel`'s models would always run them.
Do you think this could work? I'm happy to work on it myself :)
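The hook idea can be sketched as a thin wrapper around a forward function. Everything here is hypothetical: `ToyCache`, `toy_forward`, and `with_legacy_cache` are illustrative names, not optimum-intel or transformers APIs, and lists stand in for tensors. The wrapper converts the incoming tuple to a cache object before the call and converts the returned cache back to a tuple afterwards, so code outside the wrapper only ever sees tuples.

```python
from typing import Callable, List, Tuple

# Legacy format assumed here: one (key, value) pair per layer.
Legacy = Tuple[Tuple[List[float], List[float]], ...]


class ToyCache:
    """Hypothetical cache object; not a transformers class."""

    def __init__(self, layers: List[Tuple[List[float], List[float]]]) -> None:
        self.layers = layers

    @classmethod
    def from_legacy(cls, legacy: Legacy) -> "ToyCache":
        return cls([(list(k), list(v)) for k, v in legacy])

    def to_legacy(self) -> Legacy:
        return tuple((list(k), list(v)) for k, v in self.layers)


def toy_forward(token: float, past_key_values: ToyCache) -> ToyCache:
    """Stand-in for a model forward that only understands cache objects."""
    for keys, values in past_key_values.layers:
        keys.append(token)
        values.append(-token)
    return past_key_values


def with_legacy_cache(
    forward: Callable[[float, ToyCache], ToyCache],
) -> Callable[[float, Legacy], Legacy]:
    """Wrap a cache-based forward so it takes and returns legacy tuples."""

    def wrapped(token: float, past_key_values: Legacy) -> Legacy:
        cache = ToyCache.from_legacy(past_key_values)  # pre-forward conversion
        cache = forward(token, cache)
        return cache.to_legacy()  # post-forward conversion

    return wrapped


legacy_forward = with_legacy_cache(toy_forward)
past: Legacy = (([], []), ([], []))  # two empty layers
for token in (1.0, 2.0):
    past = legacy_forward(token, past)
```

In a real model the same two conversions might live in forward hooks or in an overridden `forward`; which of those fits optimum-intel is left open here.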
I am not sure about that, because our own past key values have a different shape from the standard cache. Do you plan to remove tuples for all models? I see that only some models, like llama, have this plan, while other models like gpt2 and opt still keep tuple-format past key values.
The long-term plan is to shift all models to the new `Cache` classes. We're focusing on new models first 🤗
`DynamicCache` and the old tuple format should have the exact same shapes!
Also cc @echarlaix here
We talked offline and optimum will need an update indeed
Can we always keep past_key_values as an option?
If the plan is to shift towards removing past_key_values, I was wondering if there were any examples that I could see?
@brucewlee the cache class that is now being used by default on modern model architectures supports seamless conversion to/from the legacy cache types 🤗 All a custom script needs to do is to convert from/to the legacy format before/after the model forward pass
See the class reference here
Feature request
I see llama will remove tuple past key values in 4.43.
Motivation
The model outputs should always keep the tuple type as an option. If we remove it, all models will fail in jit trace.
Your contribution
@gante Do you mind taking a look at it? Thx!
cc @amyeroberts