huangyuxiang03 opened 2 weeks ago
Hey @huangyuxiang03! Yes, any model that supports the new cache object should work with the quantized cache out of the box. The problem with Phi might be the one in https://github.com/huggingface/transformers/issues/32945#issuecomment-2324184723 or the issue linked there.
I will run the provided script to see what went wrong, but I also recommend trying generation with the model's `generate()` method to see whether the issue is in the custom generation loop or not. You can also find information about the KV cache here: https://huggingface.co/docs/transformers/v4.44.0/kv_cache 😄
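For example, something along these lines (mirroring the arguments used in the script further down this thread) should tell you whether `generate()` itself produces sensible text with the quantized cache:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "microsoft/Phi-3-mini-128k-instruct"  # the checkpoint this issue is about
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda:0")

# If this output is coherent while the custom loop's output is not, the bug is
# in the custom loop rather than in the quantized cache itself.
out = model.generate(
    **inputs,
    max_new_tokens=30,
    cache_implementation="quantized",
    cache_config={"nbits": 2, "backend": "quanto"},
)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```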
Thanks for your reply. I think it might not be the same problem as #32945; I tried the code provided in #32945 but could not reproduce the corresponding bug. The current implementation of Phi-3 does not support KV cache quantization, and it does not have a custom generation function. Thus I assume the simplest way of applying KV cache quantization is to do so out of the box, or to simply turn on the variable that allows KV cache quantization in `Phi3PreTrainedModel`. I sincerely appreciate your support.
In fact, I am not sure whether the performance degradation is caused by a bug or by quantization itself. From my experience with KV cache quantization, it does cause fairly large performance degradation when applying 2-bit quantization to certain models. However, Phi-3 with 2-bit quantization loses almost all language capability, which seems a little weird.
@huangyuxiang03 I see, quantization might result in performance degradation, especially if you're using only 2 bits. Depending on the backend, there are some params that work best; it seems you're using the defaults and Quanto, which is expected to have decent generation quality overall.
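For example, a sketch of the knobs that usually help (the values are illustrative, not tuned for Phi-3; the keys follow `QuantizedCacheConfig`, and `model`/`tokenizer`/`inputs` are assumed to be set up as in the script below):

```python
out = model.generate(
    **inputs,
    max_new_tokens=30,
    cache_implementation="quantized",
    cache_config={
        "backend": "quanto",
        "nbits": 4,              # quanto supports 2 or 4; 4 usually degrades quality far less
        "q_group_size": 64,      # quantization group size
        "residual_length": 128,  # most recent tokens kept un-quantized
    },
)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```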
> The current implementation of Phi-3 does not support kv cache quantization

Indeed, it doesn't have the flag set, and AFAIR there wasn't any specific reason why a quantized cache couldn't be supported. It was probably missed when adding cache support for Phi-3. Let me add the flag and try running generation.
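For reference, a minimal sketch of what that change looks like, assuming the attribute names follow the convention other models (e.g. Llama) use:

```python
from transformers import PreTrainedModel

# Sketch only: the real Phi3PreTrainedModel in modeling_phi3.py also defines config_class,
# base_model_prefix, etc. The relevant addition is the quantized-cache flag.
class Phi3PreTrainedModel(PreTrainedModel):
    _supports_cache_class = True       # Phi-3 already supports the new Cache API
    _supports_quantized_cache = True   # assumed flag name, following other models
```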
@huangyuxiang03 Okay, got it working. The issue was that `cache_position` was not being passed, and we can't infer it correctly if the cache is quantized. Also, feel free to open a PR adding the `supports_quantized` flag to the Phi-3 model; I just tested that it works.
Generation with `transformers` and the custom loop match in the script below :)
```python
from transformers.cache_utils import QuantoQuantizedCache, QuantizedCacheConfig
from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer
import torch

model_identifier = "microsoft/Phi-3-mini-4k-instruct"
model = AutoModelForCausalLM.from_pretrained(model_identifier, device_map="cuda:0", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_identifier)
inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda:0")

# Built-in path: let generate() construct the quantized cache internally
out = model.generate(**inputs, max_length=50, cache_implementation="quantized", cache_config={"nbits": 2, "backend": "quanto"})
print(tokenizer.batch_decode(out, skip_special_tokens=True))


@torch.no_grad()
def gen(model, input_ids, attention_mask, max_new_tokens, eos_token_id):
    # Custom greedy decoding loop with an explicitly constructed quantized cache
    past_key_values = QuantoQuantizedCache(QuantizedCacheConfig(nbits=2, compute_dtype=torch.bfloat16))

    # Prefill: cache_position has to be passed explicitly, since it cannot be
    # inferred correctly when the cache is quantized
    cache_position = torch.arange(input_ids.shape[1], device="cuda:0")
    output = model(input_ids, attention_mask=attention_mask, past_key_values=past_key_values, cache_position=cache_position)
    past_key_values = output.past_key_values

    generated_tokens = []
    next_id = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated_tokens.append(next_id.item())

    # Extend the attention mask by one position and advance cache_position
    next_mask = torch.ones((attention_mask.shape[0], 1), device="cuda:0")
    attention_mask = torch.cat([attention_mask, next_mask], dim=-1)
    cache_position = cache_position[-1:] + 1

    for _ in range(max_new_tokens - 1):
        output = model(next_id, attention_mask=attention_mask, past_key_values=past_key_values, cache_position=cache_position)
        past_key_values = output.past_key_values
        next_id = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        if next_id.item() == eos_token_id:
            break
        generated_tokens.append(next_id.item())
        next_mask = torch.ones((attention_mask.shape[0], 1), device="cuda:0")
        attention_mask = torch.cat([attention_mask, next_mask], dim=-1)
        cache_position = cache_position[-1:] + 1

    generated_tokens = torch.tensor(generated_tokens, device=input_ids.device, dtype=input_ids.dtype).unsqueeze(0)
    input_ids = torch.cat((input_ids, generated_tokens), dim=-1)
    print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))
    return input_ids


gen_ids = gen(model, inputs.input_ids, inputs.attention_mask, max_new_tokens=30, eos_token_id=model.config.eos_token_id)
print(gen_ids)
```
Thanks for the provided code. I'm a bit confused: here, `cache_position` is inferred if the `cache_position` passed in is `None`. From my observation, `past_key_values.get_seq_length()` gives the correct length, so there shouldn't be any problem inferring `cache_position`.
Besides, https://github.com/huggingface/transformers/blob/2d3708581742b45f190903a740cd1e69030dbc9f/src/transformers/models/phi3/modeling_phi3.py#L1249 doesn't pass `cache_position` to `Phi3Model`, which I think might be a bug.
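To illustrate what I mean by inferring it (a rough sketch of the idea, reusing the `past_key_values` and `input_ids` names from the script above, not the exact transformers code path):

```python
import torch

# When cache_position is None, it should in principle be derivable from the
# cache's reported length, since get_seq_length() returns the correct value here.
past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(past_len, past_len + input_ids.shape[1], device=input_ids.device)
```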
The code provided above gives the same output as the code snippet I provided at the very beginning. I'm testing with Phi-3-mini-128K-instruct, and the model seems fine within a 4096 context length. If the input sequence is longer, the model starts acting strangely and gives random output.
@huangyuxiang03 If the garbage output starts after a certain length, then it is most probably related to the linked issue. For `cache_position`: yes, we try to infer it from the cache length, but the current implementation of the quantized cache uses a deprecated attribute for the sequence length. In general, I'd recommend passing in your own `cache_position`, as some cache objects don't or can't return the sequence length.
System Info
transformers version: 4.44.2

Who can help?
@ArthurZucker @gante
Reproduction
Expected behavior
The code snippet provided above generates random output for Phi-3-mini-128K, a model that does not originally support KV cache quantization.
However, from my understanding of the quantized cache supported in Hugging Face Transformers, one can simply replace an instance of `DynamicCache` with `QuantoQuantizedCache` to enable KV cache quantization. This is also mentioned in https://github.com/huggingface/transformers/pull/30483#issuecomment-2122117638. Phi-3-mini-128K is a fairly standard decoder-only transformer with only a few structural modifications compared with Llama, so I believe that if the quantized cache works correctly on Llama (which it does), it should also work correctly on Phi-3. The code snippet generates high-quality output on Llama, but it generates random tokens on Phi-3.
Besides, could you provide a README to teach model contributors how to enable KV cache quantization?
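For illustration, the swap I have in mind is roughly the following (a sketch assuming the model already accepts the new cache objects, with `model` and `inputs` set up as in the reproduction snippet):

```python
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache, QuantizedCacheConfig
import torch

# Instead of the default full-precision dynamic cache ...
# past_key_values = DynamicCache()
# ... construct a quantized cache and pass it through past_key_values:
past_key_values = QuantoQuantizedCache(QuantizedCacheConfig(nbits=4, compute_dtype=torch.bfloat16))
out = model(**inputs, past_key_values=past_key_values, use_cache=True)
```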