huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Questions about supporting KV Cache quantization for models that do not currently support quantized cache #33231

Open huangyuxiang03 opened 2 weeks ago

huangyuxiang03 commented 2 weeks ago

System Info

Who can help?

@ArthurZucker @gante

Information

Tasks

Reproduction

import torch

from transformers.cache_utils import QuantoQuantizedCache, QuantizedCacheConfig

# Chunked prefill (BS tokens at a time) followed by greedy decoding, one token per step.
BS = 1024
@torch.no_grad()
def gen(model, input_ids, max_new_tokens, eos_token_id):
    # 2-bit Quanto-backed quantized KV cache, used in place of the default DynamicCache
    past_key_values = QuantoQuantizedCache(QuantizedCacheConfig(nbits=2, compute_dtype=torch.bfloat16))
    for b in range(0, input_ids.shape[-1], BS):
        e = min(input_ids.shape[-1], b + BS)
        output = model(input_ids[:, b:e], past_key_values=past_key_values)
        past_key_values = output.past_key_values

    generated_tokens = []
    input_id = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated_tokens.append(input_id.item())

    for _ in range(max_new_tokens-1):
        output = model(input_id, past_key_values=past_key_values)
        past_key_values = output.past_key_values
        input_id = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        if input_id.item() == eos_token_id:
            break
        generated_tokens.append(input_id.item())
    generated_tokens = torch.tensor(generated_tokens, device=input_ids.device, dtype=input_ids.dtype).unsqueeze(0)
    input_ids = torch.cat((input_ids, generated_tokens), dim=-1)
    return input_ids
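
(A driver along these lines is assumed when running the snippet above; the checkpoint name and prompt below are placeholders and were not part of the original report.)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical setup for calling gen(); only meant to make the repro self-contained.
model_id = "microsoft/Phi-3-mini-128k-instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids.to("cuda:0")
out_ids = gen(model, input_ids, max_new_tokens=30, eos_token_id=model.config.eos_token_id)
print(tokenizer.batch_decode(out_ids, skip_special_tokens=True))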

Expected behavior

The code snippet provided above generates random output for phi-3-mini-128K, a model that does not currently have KV Cache quantization enabled.

However, from my understanding of the quantized cache supported in Hugging Face Transformers, one can simply replace an instance of DynamicCache with QuantoQuantizedCache to enable KV Cache quantization. This is also mentioned in https://github.com/huggingface/transformers/pull/30483#issuecomment-2122117638. Phi-3-mini-128K is a fairly standard decoder-only transformer with only a few structural modifications compared with Llama, so I believe that if the quantized cache works correctly on Llama (which it does), it should also work correctly on Phi-3. The code snippet generates high-quality output on Llama, but it produces random tokens on Phi-3.
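
A minimal sketch of the swap described above (model stands for any already-loaded causal LM):

import torch
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache, QuantizedCacheConfig

# Default full-precision cache
past_key_values = DynamicCache()

# Drop-in replacement: Quanto-backed 2-bit quantized cache
past_key_values = QuantoQuantizedCache(QuantizedCacheConfig(nbits=2, compute_dtype=torch.bfloat16))

# Either cache object is passed to the forward pass in the same way, e.g.:
# output = model(input_ids, past_key_values=past_key_values)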

Besides, could you provide a README that teaches model contributors how to enable KV Cache quantization?

zucchini-nlp commented 2 weeks ago

Hey @huangyuxiang03 ! Yes, any model that supports the new cache object should work with the quantized cache out of the box. The problem with Phi might be https://github.com/huggingface/transformers/issues/32945#issuecomment-2324184723 or the issue linked there.

I will run the provided script to see what went wrong, but I also recommend trying generation with the model's generate() method to see whether the issue is in the custom generation loop. You can also find information about the KV cache here: https://huggingface.co/docs/transformers/v4.44.0/kv_cache 😄

huangyuxiang03 commented 2 weeks ago

Thanks for your reply. I think it might not be the same problem as #32945: I tried the code provided in #32945 but could not reproduce that bug. The current implementation of Phi-3 does not support KV cache quantization, and it does not have a custom generation function, so I assume the simplest way to apply KV cache quantization is either to do it out of the box or to simply turn on the flag that allows KV cache quantization in Phi3PreTrainedModel. I sincerely appreciate your support.

huangyuxiang03 commented 2 weeks ago

In fact, I am not sure whether the performance degradation is caused by a bug or by the quantization itself. In my experience with KV cache quantization, 2-bit quantization does cause a fairly large quality drop on some models. However, Phi-3 with 2-bit quantization loses almost all language capability, which seems a little odd.

zucchini-nlp commented 2 weeks ago

@huangyuxiang03 I see. Quantization can cause performance degradation, especially at only 2 bits. Depending on the backend, some parameters work better than others; it seems you are using the defaults with Quanto, which is expected to give decent generation quality overall.
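
If the degradation mostly comes from the aggressive 2-bit setting, a less aggressive config is an easy sanity check (a sketch; the values below are illustrative and defaults may differ between versions):

import torch
from transformers.cache_utils import QuantoQuantizedCache, QuantizedCacheConfig

# 4-bit is usually much closer to full-precision quality than 2-bit;
# q_group_size and residual_length are the other main knobs on QuantizedCacheConfig.
config = QuantizedCacheConfig(
    backend="quanto",
    nbits=4,
    q_group_size=64,
    residual_length=128,
    compute_dtype=torch.bfloat16,
)
past_key_values = QuantoQuantizedCache(config)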

Regarding "The current implementation of Phi-3 does not support kv cache quantization": indeed, it doesn't have the flag set, and as far as I remember there was no specific reason why the quantized cache could not be supported. It was probably just missed when cache support was added for Phi-3. Let me add the flag and try running generation.
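
For reference, enabling this for a model is normally just a class attribute on its PreTrainedModel subclass. A sketch of what the Phi-3 change would look like, assuming the flag is named _supports_quantized_cache (worth double-checking against the current source):

# Sketch of the change in src/transformers/models/phi3/modeling_phi3.py
class Phi3PreTrainedModel(PreTrainedModel):
    config_class = Phi3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_cache_class = True
    _supports_quantized_cache = True  # assumed flag name checked by generate() before building a quantized cache
    ...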

zucchini-nlp commented 2 weeks ago

@huangyuxiang03 Okay, got it working. The issue was that cache_position was not being passed, and we can't infer it correctly when the cache is quantized. Also, feel free to open a PR to add the supports_quantized flag to the Phi-3 model; I just tested that it works.

Generation by transformers and the custom loop match in the script below :)

from transformers.cache_utils import QuantoQuantizedCache, QuantizedCacheConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_identifier = 'microsoft/Phi-3-mini-4k-instruct'
model = AutoModelForCausalLM.from_pretrained(model_identifier, device_map='cuda:0', torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_identifier)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda:0")

# Reference run: let generate() manage the quantized cache itself
out = model.generate(**inputs, max_length=50, cache_implementation="quantized", cache_config={"nbits": 2, "backend": "quanto"})
print(tokenizer.batch_decode(out, skip_special_tokens=True))

@torch.no_grad()
def gen(model, input_ids, attention_mask, max_new_tokens, eos_token_id):
    past_key_values = QuantoQuantizedCache(QuantizedCacheConfig(nbits=2, compute_dtype=torch.bfloat16))
    # Key fix: pass cache_position explicitly; it cannot be inferred reliably when the cache is quantized
    cache_position = torch.arange(input_ids.shape[1], device="cuda:0")
    output = model(input_ids, attention_mask=attention_mask, past_key_values=past_key_values, cache_position=cache_position)
    past_key_values = output.past_key_values

    generated_tokens = []
    next_id = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated_tokens.append(next_id.item())

    # Grow the attention mask by one position and advance cache_position for the next step
    next_mask = torch.ones((attention_mask.shape[0], 1), device="cuda:0")
    attention_mask = torch.cat([attention_mask, next_mask], dim=-1)
    cache_position = cache_position[-1:] + 1

    for _ in range(max_new_tokens-1):
        output = model(next_id, attention_mask=attention_mask, past_key_values=past_key_values, cache_position=cache_position)
        past_key_values = output.past_key_values
        next_id = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        if next_id.item() == eos_token_id:
            break
        generated_tokens.append(next_id.item())

        next_mask = torch.ones((attention_mask.shape[0], 1), device="cuda:0")
        attention_mask = torch.cat([attention_mask, next_mask], dim=-1)
        cache_position = cache_position[-1:] + 1

    generated_tokens = torch.tensor(generated_tokens, device=input_ids.device, dtype=input_ids.dtype).unsqueeze(0)
    input_ids = torch.cat((input_ids, generated_tokens), dim=-1)
    print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))
    return input_ids

gen_ids = gen(model, inputs.input_ids, inputs.attention_mask, max_new_tokens=30, eos_token_id=model.config.eos_token_id)
print(gen_ids)

huangyuxiang03 commented 2 weeks ago

Thanks for the provided code. I'm a bit confused that cache_position is inferred here when the cache_position passed in is None: from my observation, past_key_values.get_seq_length() gives the correct length, so there shouldn't be any problem inferring cache_position.

Besides, https://github.com/huggingface/transformers/blob/2d3708581742b45f190903a740cd1e69030dbc9f/src/transformers/models/phi3/modeling_phi3.py#L1249 doesn't pass cache_position to Phi3Model, which I think might be a bug.

The code provided above gives the same output as the code snippet I provided at the very beginning. I'm testing with phi-3-mini-128K-instruct, and the model seems fine within a 4096-token context length. If the input sequence is longer, the model starts acting strangely and produces random output.

zucchini-nlp commented 2 weeks ago

@huangyuxiang03 If the garbage output starts after a certain length, then it is most probably related to the linked issue. For cache_position: yes, we try to infer it from the cache length, but the current implementation of the quantized cache uses a deprecated attribute for the sequence length. In general I'd recommend passing in your own cache_position, as some cache objects don't or can't return the sequence length.
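
To make the failure mode concrete, the inference is roughly of the following shape (an illustrative sketch, not the exact library code). If get_seq_length() is wrong, every subsequent position is shifted, whereas an explicitly passed cache_position sidesteps the lookup entirely:

import torch

def infer_cache_position(past_key_values, input_ids):
    # When cache_position is None, the model derives it from how many tokens
    # the cache reports having already seen.
    past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
    # A wrong get_seq_length() (e.g. a quantized cache relying on a deprecated
    # attribute) shifts every position from here on, breaking RoPE and the mask.
    return torch.arange(past_seen, past_seen + input_ids.shape[1], device=input_ids.device)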