huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

MllamaForCausalLM not returning past_key_values even with use_cache=True #34206

Open Zhangyanbo opened 1 month ago

Zhangyanbo commented 1 month ago

System Info

Who can help?

@amyeroberts @ArthurZucker

Information

Tasks

Reproduction

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
import torch

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# For this checkpoint, AutoModelForCausalLM resolves to MllamaForCausalLM.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)

input_ids = tokenizer.encode("Hi, tell me a story of frog.", add_special_tokens=False, return_tensors="pt").to(model.device)

# Single forward pass with caching requested.
with torch.no_grad():
    output = model(input_ids=input_ids, use_cache=True)

print(output.past_key_values)  # None, even though use_cache=True was passed

Expected behavior

I expect past_key_values to be present in the output. However, I got None.

zucchini-nlp commented 1 month ago

Hmm, actually we implemented Mllama quite similarly to Idefics, so the cache is not initialized by default when use_cache=True. And yes, I think it makes sense to init an empty cache for models that don't need a special cache class the way Gemma does.

Until the fix is there, you can get past_key_values by passing model(**inputs, past_key_values=DynamicCache(), use_cache=True). But I see that the model weights will not be loaded properly for MllamaForCausalLM. In fact, MllamaForConditionalGeneration can handle text-only input, so for proper logits computation I'd recommend using the conditional-generation model :)
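Roughly, the workaround looks like this (a quick sketch, not run end-to-end; it assumes DynamicCache is exported from the top-level transformers namespace and that the Mllama processor accepts a text-only prompt):

from transformers import MllamaForConditionalGeneration, AutoProcessor, DynamicCache
import torch

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# The conditional-generation class loads the full checkpoint correctly
# and also handles text-only prompts.
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)

inputs = processor(text="Hi, tell me a story of frog.", return_tensors="pt").to(model.device)

with torch.no_grad():
    # Passing an explicit DynamicCache makes the model populate and return it.
    output = model(**inputs, past_key_values=DynamicCache(), use_cache=True)

print(output.past_key_values)  # populated cache instead of None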

ArthurZucker commented 1 month ago

With use_cache=True we should probably just init a default cache for the user, or we could opt to force users to pass a cache object.
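For illustration, the first option amounts to something like the hypothetical helper below, mirroring the fallback most other decoder models in the library already apply (a sketch of the idea only, not the actual patch; maybe_init_cache is a made-up name):

from transformers import DynamicCache

def maybe_init_cache(past_key_values, use_cache):
    # If caching is requested but no cache object was passed, start from an
    # empty DynamicCache instead of silently returning None.
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()
    return past_key_values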

zucchini-nlp commented 1 month ago

Yes, exactly. I can make a PR for that

github-actions[bot] commented 1 day ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.