huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

MllamaForCausalLM not returning past_key_values even with use_cache=True #34206

Open Zhangyanbo opened 1 day ago

Zhangyanbo commented 1 day ago

System Info

Who can help?

@amyeroberts @ArthurZucker

Information

Tasks

Reproduction

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
import torch

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Resolves to MllamaForCausalLM for this checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map='auto',
)

processor = AutoProcessor.from_pretrained(model_id)

input_ids = tokenizer.encode('Hi, tell me a story of frog.', add_special_tokens=False, return_tensors='pt').to(model.device)

with torch.no_grad():
    # use_cache=True should make the forward pass return past_key_values
    output = model(input_ids=input_ids, use_cache=True)

print(output.past_key_values)  # prints None instead of a cache

Expected behavior

I expect to see past_key_values in the output. However, I get None.
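For comparison, the same kind of call on a plain text-only Llama checkpoint does return a cache (sketch; the checkpoint name below is just an example):

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

ref_id = 'meta-llama/Llama-3.2-1B'  # any plain text-only Llama checkpoint
ref_tokenizer = AutoTokenizer.from_pretrained(ref_id)
ref_model = AutoModelForCausalLM.from_pretrained(ref_id, torch_dtype=torch.float16, device_map='auto')

ids = ref_tokenizer('Hi, tell me a story of frog.', return_tensors='pt').input_ids.to(ref_model.device)
with torch.no_grad():
    ref_out = ref_model(input_ids=ids, use_cache=True)

print(ref_out.past_key_values is not None)  # True for a plain Llama text model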

zucchini-nlp commented 22 hours ago

Hmm, we actually implemented Mllama quite similarly to Idefics, so the cache is not initialized by default when use_cache=True. And yes, I think it makes sense to init an empty cache for models that don't need a special cache class the way Gemma does.

Until the fix lands you can get past_key_values by passing model(**inputs, past_key_values=DynamicCache(), use_cache=True), but I see that the model weights will not be loaded properly for MllamaForCausalLM. In fact MllamaForConditionalGeneration can handle text-only input, so for correct logits I'd recommend using the conditional-generation model :)
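Something like this (untested sketch, assuming text-only input can go straight through the processor and that DynamicCache is importable from the top-level transformers package):

from transformers import MllamaForConditionalGeneration, AutoProcessor, DynamicCache
import torch

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Load the conditional-generation model so all weights (including the vision
# tower) are initialized, then run it on text-only input.
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map='auto',
)
processor = AutoProcessor.from_pretrained(model_id)

inputs = processor(text='Hi, tell me a story of frog.', return_tensors='pt').to(model.device)

with torch.no_grad():
    # Passing an explicit DynamicCache makes the forward pass return it filled.
    output = model(**inputs, past_key_values=DynamicCache(), use_cache=True)

print(output.past_key_values)  # should now be a populated cache instead of None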

ArthurZucker commented 21 hours ago

With use_cache=True we should probably just init a default cache for the user, or we could opt for forcing users to pass a cache object.
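Roughly, something along these lines inside the text model's forward (hypothetical sketch, names are illustrative and the actual fix may differ):

from typing import Optional
from transformers import DynamicCache

def _maybe_init_cache(past_key_values: Optional[DynamicCache], use_cache: bool):
    # Hypothetical helper: if caching was requested but no cache object was
    # passed, start from an empty DynamicCache so the forward pass has
    # something to fill and return to the user.
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()
    return past_key_values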

zucchini-nlp commented 21 hours ago

Yes, exactly. I can make a PR for that