huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

LLaVA cannot use beam search after 4.43.0 #32234

Closed: snorfyang closed this issue 3 months ago

snorfyang commented 3 months ago

System Info

transformers >= 4.43.0

Who can help?

@zucchini-nlp

Reproduction

from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")

model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to("cuda:0")

# prepare image and text prompt, using the appropriate prompt template
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)

# Define a chat history and use `apply_chat_template` to get a correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image")
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")

# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=100, num_beams=5)

print(processor.decode(output[0], skip_special_tokens=True))
NotImplementedError: Make sure that a `_reorder_cache` function is correctly implemented
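Some background on why beam search needs a cache reorder at all: at every step the surviving hypotheses may descend from any of the previous beams, so generation has to record which parent each new beam came from (a `beam_idx`) and rearrange the cached keys/values to match. Below is a stdlib-only toy of one beam-search step; `beam_step` and its data layout are illustrative assumptions, not the transformers implementation.

```python
import math


def beam_step(beam_scores, next_token_logprobs, num_beams):
    """One step of a toy beam search.

    beam_scores[b] is the cumulative log-probability of live beam b.
    next_token_logprobs[b][t] is log p(token t | beam b).
    Returns (new_scores, beam_idx, tokens); beam_idx[i] is the parent beam
    that the i-th surviving hypothesis extends.
    """
    candidates = []
    for parent, (score, row) in enumerate(zip(beam_scores, next_token_logprobs)):
        for token, logprob in enumerate(row):
            candidates.append((score + logprob, parent, token))
    # Keep the num_beams highest-scoring (parent, token) extensions.
    candidates.sort(key=lambda c: c[0], reverse=True)
    best = candidates[:num_beams]
    new_scores = [c[0] for c in best]
    beam_idx = [c[1] for c in best]
    tokens = [c[2] for c in best]
    return new_scores, beam_idx, tokens


# Two live beams over a 3-token vocabulary.
scores = [math.log(0.6), math.log(0.4)]
logprobs = [
    [math.log(0.5), math.log(0.45), math.log(0.05)],
    [math.log(0.4), math.log(0.3), math.log(0.3)],
]
_, beam_idx, tokens = beam_step(scores, logprobs, num_beams=2)
print(beam_idx, tokens)  # → [0, 0] [0, 1]
```

Here beam 1 is dropped and beam 0 is duplicated, so row 1 of every cached key/value tensor must now hold beam 0's state. Reordering the cache accordingly is the job that `_reorder_cache` did for legacy caches.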

Expected behavior

The `_reorder_cache` function was removed in #31898.
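For context, as I understand the pre-#31898 Llama code, the legacy tuple cache (one `(key, value)` pair per layer) was reordered by index-selecting every state along the batch/beam dimension with `beam_idx`. The sketch below mimics that with plain Python lists standing in for tensors; `reorder_legacy_cache` is a hypothetical name, and in the real code the selection is a tensor `index_select` along dim 0.

```python
def reorder_legacy_cache(past_key_values, beam_idx):
    """Toy analogue of a legacy `_reorder_cache`.

    past_key_values: tuple over layers of (key, value); each is a list indexed
    by beam (a stand-in for the batch dimension of a tensor).
    Returns a new tuple in which row i holds the state of parent beam beam_idx[i].
    """
    reordered = ()
    for layer_past in past_key_values:
        # For both key and value, row beam_idx[i] becomes the new row i.
        reordered += (tuple([state[i] for i in beam_idx] for state in layer_past),)
    return reordered


# One layer, two beams; each beam is labelled by the string it holds.
past = ((["k_beam0", "k_beam1"], ["v_beam0", "v_beam1"]),)
print(reorder_legacy_cache(past, [0, 0]))
# → ((['k_beam0', 'k_beam0'], ['v_beam0', 'v_beam0']),)
```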

zucchini-nlp commented 3 months ago

Hey! Thanks for reporting. Yes, the issue is probably caused by Llama moving completely to the new cache format, while Llava wasn't changed accordingly.
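A simplified model of the mismatch described here: roughly, beam search asks a new-style cache *object* to reorder itself, but when a model still hands generation a legacy tuple, it falls back to the model's `_reorder_cache` hook. Once that hook is gone, the fallback raises the NotImplementedError seen in the reproduction. `ToyCache`, `ToyModel`, and `reorder_for_beam_search` are hypothetical names; this sketches the dispatch idea, not the transformers source.

```python
class ToyCache:
    """Hypothetical stand-in for a new-style cache object that reorders itself."""

    def __init__(self, keys, values):
        self.keys = keys      # list over layers; each entry is a list indexed by beam
        self.values = values

    def reorder_cache(self, beam_idx):
        self.keys = [[layer[i] for i in beam_idx] for layer in self.keys]
        self.values = [[layer[i] for i in beam_idx] for layer in self.values]


class ToyModel:
    """Hypothetical model whose legacy `_reorder_cache` hook no longer exists."""

    def _reorder_cache(self, past_key_values, beam_idx):
        raise NotImplementedError(
            "Make sure that a `_reorder_cache` function is correctly implemented"
        )


def reorder_for_beam_search(model, past_key_values, beam_idx):
    # New-style cache objects know how to reorder themselves...
    if hasattr(past_key_values, "reorder_cache"):
        past_key_values.reorder_cache(beam_idx)
        return past_key_values
    # ...while legacy tuples fall back to the model-level hook.
    return model._reorder_cache(past_key_values, beam_idx)


model = ToyModel()
cache = reorder_for_beam_search(model, ToyCache([["k0", "k1"]], [["v0", "v1"]]), [1, 1])
print(cache.keys, cache.values)  # → [['k1', 'k1']] [['v1', 'v1']]
try:
    reorder_for_beam_search(model, ((["k0", "k1"], ["v0", "v1"]),), [1, 1])
except NotImplementedError as err:
    print("legacy tuple:", err)
```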

I'll make a PR today. After it is merged, you can pull the fix into your env by updating transformers with !pip install --upgrade git+https://github.com/huggingface/transformers.git :)
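If a script should flag installs in the range this report calls affected, here is a minimal sketch. It only encodes the lower bound stated in the report (4.43.0); the thread does not say which release ships the fix, so every later version, including dev builds installed from git, is treated as potentially affected. `parse_release` and `in_reported_range` are hypothetical helpers.

```python
def parse_release(version):
    """Return (major, minor, patch) as ints from strings like '4.43.0.dev0'."""
    parts = []
    for piece in version.split(".")[:3]:
        # Keep only the leading digits, so '0rc1' -> 0 and 'dev0' -> 0.
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def in_reported_range(version):
    """True if `version` is at or above 4.43.0, the lower bound in this report."""
    return parse_release(version) >= (4, 43, 0)


for v in ["4.42.4", "4.43.0", "4.44.0.dev0"]:
    print(v, in_reported_range(v))
```

In practice you would pass `transformers.__version__` after `import transformers`.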

snorfyang commented 3 months ago

Also, I can't use contrastive search either, even after downgrading to 4.42.4.
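Background on this second failure: contrastive search, selected in the reproduction below by `penalty_alpha=0.6, top_k=4`, takes the `top_k` most likely next tokens, runs the model on each candidate to get its hidden state (the traceback below shows this as assembling `top_k_ids` into a batch of size k), and keeps the candidate v maximizing (1 - α)·p(v) - α·max_j cos(h_v, h_j), where α is `penalty_alpha` and h_j ranges over the previous tokens' hidden states. A stdlib-only sketch of that selection rule, with toy 2-D vectors standing in for hidden states; the function names are illustrative, not the transformers internals.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def contrastive_pick(candidate_probs, candidate_hidden, context_hidden, penalty_alpha):
    """Return the index of the best top-k candidate under the contrastive score.

    score(v) = (1 - alpha) * p(v) - alpha * max_j cos(h_v, h_j)
    """
    best_idx, best_score = None, -math.inf
    for i, (p, h) in enumerate(zip(candidate_probs, candidate_hidden)):
        # Degeneration penalty: similarity to the most similar previous token.
        degeneration = max(cosine(h, c) for c in context_hidden)
        score = (1 - penalty_alpha) * p - penalty_alpha * degeneration
        if score > best_score:
            best_idx, best_score = i, score
    return best_idx


context = [[1.0, 0.0]]               # hidden state of one previous token
probs = [0.6, 0.4]                   # model confidence for two candidates
hidden = [[1.0, 0.0], [0.0, 1.0]]    # candidate 0 repeats the context direction
print(contrastive_pick(probs, hidden, context, penalty_alpha=0.0))  # → 0 (greedy)
print(contrastive_pick(probs, hidden, context, penalty_alpha=0.6))  # → 1
```

With α = 0 the rule reduces to picking the most probable candidate; a larger α increasingly penalizes candidates whose representation repeats the context.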

import requests
from PIL import Image

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)

processor = AutoProcessor.from_pretrained(model_id)

# Define a chat history and use `apply_chat_template` to get a correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image")
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What are these?"},
            {"type": "image"},
        ],
    },
]

processor.chat_template = "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)

output = model.generate(**inputs, max_new_tokens=200, penalty_alpha=0.6, top_k=4)
print(processor.decode(output[0][2:], skip_special_tokens=True))
IndexError                                Traceback (most recent call last)
Cell In[2], line 36
     33 raw_image = Image.open(requests.get(image_file, stream=True).raw)
     34 inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
---> 36 output = model.generate(**inputs, max_new_tokens=200, penalty_alpha=0.6, top_k=4)
     37 print(processor.decode(output[0][2:], skip_special_tokens=True))

File /usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py:1887, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1881     if self._is_stateful:
   1882         # Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
   1883         raise ValueError(
   1884             f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
   1885         )
-> 1887     result = self._contrastive_search(
   1888         input_ids,
   1889         logits_processor=prepared_logits_processor,
   1890         stopping_criteria=prepared_stopping_criteria,
   1891         generation_config=generation_config,
   1892         synced_gpus=synced_gpus,
   1893         streamer=streamer,
   1894         **model_kwargs,
   1895     )
   1897 elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
   1898     # 11. prepare logits warper
   1899     prepared_logits_warper = (
   1900         self._get_logits_warper(generation_config, device=input_ids.device)
   1901         if generation_config.do_sample
   1902         else None
   1903     )

File /usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py:2367, in GenerationMixin._contrastive_search(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   2362 else:
   2363     # compute the candidate tokens by the language model and collect their hidden_states
   2364     # assembles top_k_ids into batch of size k
   2365     next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
-> 2367     outputs = self(
   2368         **next_model_inputs,
   2369         return_dict=True,
   2370         output_hidden_states=True,
   2371         output_attentions=output_attentions,
   2372     )
   2374 # This is essential to avoid having a last reference to the big past K-V and double the necesary memory
   2375 # in the next loop
   2376 del next_model_inputs

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.11/dist-packages/transformers/models/llava/modeling_llava.py:439, in LlavaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, vision_feature_layer, vision_feature_select_strategy, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    437     image_features = self.multi_modal_projector(selected_image_feature)
    438     inputs_embeds = inputs_embeds.to(image_features.dtype)
--> 439     inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
    440         image_features, inputs_embeds, input_ids, attention_mask, labels
    441     )
    443 # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
    444 # generation with cache
    445 elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
    446     # Retrieve the first layer to inspect the logits and mask out the hidden states
    447     # that are set to 0

File /usr/local/lib/python3.11/dist-packages/transformers/models/llava/modeling_llava.py:280, in LlavaForConditionalGeneration._merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels)
    278 num_images, num_image_patches, embed_dim = image_features.shape
    279 batch_size, sequence_length = input_ids.shape
--> 280 left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
    281 # 1. Create a mask to know where special image tokens are
    282 special_image_token_mask = input_ids == self.config.image_token_index

IndexError: index -1 is out of bounds for dimension 1 with size 0
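Separately from the crash, it can help to confirm which prompt the custom `chat_template` in the reproduction above produces. The function below is a plain-Python transcription of that Jinja template (role prefix, then every image placeholder, then every text part, then `ASSISTANT:`); `render_llava_prompt` is a hypothetical helper for illustration, not something the processor calls.

```python
def render_llava_prompt(messages, add_generation_prompt=True):
    """Mirror the Jinja chat_template from the reproduction in plain Python."""
    out = ""
    for message in messages:
        if message["role"] != "system":
            out += message["role"].upper() + ": "
        # Render all images first...
        for content in message["content"]:
            if content["type"] == "image":
                out += "<image>\n"
        # ...then all text, each followed by a space.
        for content in message["content"]:
            if content["type"] == "text":
                out += content["text"] + " "
    if add_generation_prompt:
        out += "ASSISTANT:"
    return out


conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What are these?"},
            {"type": "image"},
        ],
    },
]
print(repr(render_llava_prompt(conversation)))
# → 'USER: <image>\nWhat are these? ASSISTANT:'
```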

freddyouellette commented 3 months ago

Hi @zucchini-nlp, did you make a PR for this? I'm still getting this error when trying to use Llama 3.1 with beam search:

NotImplementedError: Make sure that a `_reorder_cache` function is correctly implemented in transformers.models.llama.modeling_llama to enable beam search for <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>

zucchini-nlp commented 3 months ago

@freddyouellette sorry, I didn't have time to open a PR. I'll raise the priority of this issue and work on it on Monday.

Btw, it is weird that you can't do beam search with Llama 3.1. Are you using the language model only (not as part of Llava) and calling generate()?