huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

LLaVA-OneVision mismatch between image features and image tokens #34625

Open agadetsky opened 1 week ago

agadetsky commented 1 week ago

System Info

Who can help?

@amyeroberts @qubvel @ArthurZucker @itaz

Reproduction

from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, BitsAndBytesConfig
import numpy as np
import torch
from PIL import Image

model_id = "llava-hf/llava-onevision-qwen2-72b-ov-hf"

# specify how to quantize the model
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config
)
processor = AutoProcessor.from_pretrained(model_id)

conversation = [
    {
        "role": "user",
        "content": [{"type": "text", "text": "Describe the image"}, {"type": "image"}],
    },
]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
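# Random 243x387 (height x width) RGB image; the resolution, not the content, is what triggers the mismatch
image = Image.fromarray(np.random.randn(243, 387, 3).astype('uint8'), 'RGB')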
inputs = processor(
    images=image,
    text=prompt,
    return_tensors="pt"
).to(model.device, torch.float16)

output_ids = model.generate(**inputs, max_new_tokens=32)

The error is the following:

ValueError                                Traceback (most recent call last)
Cell In[235], line 16
      9 image = Image.fromarray(np.random.randn(243, 387, 3).astype('uint8'), 'RGB')
     10 inputs = processor(
     11     images=image,
     12     text=prompt,
     13     return_tensors="pt"
     14 ).to(model.device, torch.float16)
---> 16 output_ids = model.generate(**inputs, max_new_tokens=32)

File ~/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/transformers/generation/utils.py:2215, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2207     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2208         input_ids=input_ids,
   2209         expand_size=generation_config.num_return_sequences,
   2210         is_encoder_decoder=self.config.is_encoder_decoder,
   2211         **model_kwargs,
   2212     )
   2214     # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2215     result = self._sample(
   2216         input_ids,
   2217         logits_processor=prepared_logits_processor,
   2218         stopping_criteria=prepared_stopping_criteria,
   2219         generation_config=generation_config,
   2220         synced_gpus=synced_gpus,
   2221         streamer=streamer,
   2222         **model_kwargs,
   2223     )
   2225 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2226     # 11. prepare beam search scorer
   2227     beam_scorer = BeamSearchScorer(
   2228         batch_size=batch_size,
   2229         num_beams=generation_config.num_beams,
   (...)
   2234         max_length=generation_config.max_length,
   2235     )

File ~/.local/lib/python3.10/site-packages/transformers/generation/utils.py:3206, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3203 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   3205 # forward pass to get next token
-> 3206 outputs = self(**model_inputs, return_dict=True)
   3208 # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
   3209 model_kwargs = self._update_model_kwargs_for_generation(
   3210     outputs,
   3211     model_kwargs,
   3212     is_encoder_decoder=self.config.is_encoder_decoder,
   3213 )

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.local/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.local/lib/python3.10/site-packages/transformers/models/llava_onevision/modeling_llava_onevision.py:684, in LlavaOnevisionForConditionalGeneration.forward(self, input_ids, pixel_values, image_sizes, pixel_values_videos, image_sizes_videos, attention_mask, position_ids, past_key_values, inputs_embeds, vision_feature_layer, vision_feature_select_strategy, vision_aspect_ratio, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
    681 n_image_features = image_features.shape[0]
    683 if n_image_tokens != n_image_features:
--> 684     raise ValueError(
    685         f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
    686     )
    687 special_image_mask = (
    688     (input_ids == self.config.image_token_index)
    689     .unsqueeze(-1)
    690     .expand_as(inputs_embeds)
    691     .to(inputs_embeds.device)
    692 )
    693 image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)

ValueError: Image features and image tokens do not match: tokens: 1890, features 1944
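The failing check enforces a one-to-one invariant: every image placeholder token in `input_ids` has to receive exactly one vision feature vector. Judging by the `special_image_mask` built from `image_token_index` right after the check, the features are then written into those placeholder positions. Below is a minimal pure-Python sketch of that invariant, assuming plain lists in place of tensors. `merge_image_features` is a hypothetical name, not the library's API.

```python
from typing import List, Sequence

Vector = List[float]


def merge_image_features(
    token_ids: Sequence[int],
    embeds: Sequence[Vector],
    image_features: Sequence[Vector],
    image_token_id: int,
) -> List[Vector]:
    """Hypothetical sketch: write one image feature into each image-token slot."""
    # Positions of the placeholder tokens the processor inserted into the prompt.
    positions = [i for i, t in enumerate(token_ids) if t == image_token_id]
    # Same condition as in the traceback: one feature per placeholder, no more, no less.
    if len(positions) != len(image_features):
        raise ValueError(
            "Image features and image tokens do not match: "
            f"tokens: {len(positions)}, features {len(image_features)}"
        )
    # Copy the text embeddings, then overwrite each placeholder slot in order.
    merged = [list(e) for e in embeds]
    for pos, feat in zip(positions, image_features):
        merged[pos] = list(feat)
    return merged
```

With 1890 placeholders and 1944 features, as in the error above, there is no one-to-one assignment, so raising is the only safe option; otherwise 54 feature vectors would be silently dropped or misplaced.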

Expected behavior

Since LLaVA-OneVision supports arbitrary image resolutions, the model is expected to generate output successfully.

zucchini-nlp commented 3 days ago

@agadetsky, it seems there is a difference in how we compute the number of image tokens in the processing code versus the modeling code. This might be related to previous bugs with numerical issues when the image resolution falls on an edge case among the possible grid resolutions (like 337 here). I'll take a look and see where the precision error is coming from.
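A minimal sketch of the kind of precision error described here, using made-up numbers rather than the library's actual code: if two code paths compute the same rescaled side length, one with a float ratio followed by truncation and one with exact integer arithmetic, they can disagree by one when the true result is an integer. All helper names are hypothetical.

```python
def scaled_side_float(side: int, target: int, reference: int) -> int:
    # Float ratio, then truncate: 100 * (57 / 100) evaluates to 56.99999999999999,
    # so int() yields 56 instead of 57.
    return int(side * (target / reference))


def scaled_side_rounded(side: int, target: int, reference: int) -> int:
    # Same float math, but round to 7 decimals before truncating so that
    # tiny representation error no longer flips the result.
    return int(round(side * (target / reference), 7))


def scaled_side_exact(side: int, target: int, reference: int) -> int:
    # Integer-only floor(side * target / reference): no float error at all.
    return (side * target) // reference


for fn in (scaled_side_float, scaled_side_rounded, scaled_side_exact):
    print(fn.__name__, fn(100, 57, 100))
```

Here `scaled_side_float` gives 56 while the other two give 57. If the processor used one form and the modeling code another, a one-unit difference in the unpadded height or width would change the number of feature rows or columns and produce a tokens-versus-features gap like the one in the error above. Having both sides call a single shared, integer-only helper removes the possibility of disagreement; rounding only narrows it.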

chenweize1998 commented 15 hours ago

Hi @zucchini-nlp, have you managed to identify the issue? I'm encountering the same error with llava-hf/llava-v1.6-mistral-7b-hf. I haven't pinpointed the specific data causing the error, since it occurs midway through training. Could you also take a look at the LLaVA-NeXT modeling file? Perhaps some of the anyres calculations are mismatched?

zucchini-nlp commented 13 hours ago

@chenweize1998 yes, it is most likely the anyres calculations. Unfortunately, I haven't had time to look in more detail; I'll try to have a look today.

EDIT: found where the precision error was and opened a PR to fix it.
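For context on where the "anyres" arithmetic enters: LLaVA-NeXT-style models pick one of several candidate canvas resolutions for each image, resize the image to fit it, and then derive the number of image tokens from that choice. The selection rule below is a sketch assumed to mirror the LLaVA-NeXT anyres rule (maximize the image pixels actually kept, then minimize the unused canvas area). The candidate list is an assumed example, not read from any checkpoint's config, and `select_grid` is a hypothetical name. Every division in it is a spot where float rounding matters if another code path recomputes the same quantity differently.

```python
from typing import List, Tuple

Size = Tuple[int, int]  # (height, width)


def select_grid(original: Size, candidates: List[Size]) -> Size:
    """Sketch: choose the candidate canvas that keeps the most image pixels,
    breaking ties by the least unused (padded) area."""
    orig_h, orig_w = original
    best = candidates[0]
    best_effective, best_wasted = -1, float("inf")
    for cand_h, cand_w in candidates:
        # Largest scale that fits the whole image inside this canvas.
        scale = min(cand_w / orig_w, cand_h / orig_h)
        down_w, down_h = int(orig_w * scale), int(orig_h * scale)
        # Pixels actually kept; upscaling past the original adds no information.
        effective = min(down_w * down_h, orig_w * orig_h)
        wasted = cand_w * cand_h - effective
        if effective > best_effective or (
            effective == best_effective and wasted < best_wasted
        ):
            best, best_effective, best_wasted = (cand_h, cand_w), effective, wasted
    return best


example_candidates = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]
print(select_grid((243, 387), example_candidates))
```

For a 243x387 image (the size from the first report) and this assumed candidate list, the sketch picks a 336x672 (height x width) canvas. The processor and the model then each have to turn that choice, plus the original size, into the same number of unpadded rows and columns; that second step is where the two code paths can drift apart.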

chenweize1998 commented 12 hours ago

@zucchini-nlp Thanks for looking into this! I've pinpointed the batch of data causing the issue and uploaded it here. The problem specifically originates from the first data point in the batch. Hope it helps with debugging.

Additionally, here’s a minimal script to reproduce the error (assuming the data point is downloaded as ./tmp.bin):

from transformers import AutoModelForVision2Seq
import torch

# Load the model
model = AutoModelForVision2Seq.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf", 
    torch_dtype=torch.bfloat16
).to("cuda:0")

# Load the problematic input
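# weights_only=False may be needed on torch>=2.6 (where it defaults to True) if the saved object is not a plain dict of tensors
inputs = torch.load("tmp.bin", weights_only=False)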
# Note: inputs['input_ids'][0] triggers the error

for k, v in inputs.items():
    inputs[k] = v.to("cuda:0")

# Generate outputs
outputs = model(**inputs)

I'm using torch==2.4.0 and transformers==4.46.2. Let me know if you need more details.
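When a batch fails midway through training, as described above, one way to narrow it down is to count the image placeholder tokens in each row of `input_ids` and compare the offending row against its image size. This is a minimal sketch: `image_tokens_per_sample` is a hypothetical helper, and it assumes the placeholder id is available as `model.config.image_token_index`, the attribute the modeling code in the traceback above reads.

```python
from typing import List


def image_tokens_per_sample(input_ids, image_token_id: int) -> List[int]:
    """Count image placeholder tokens in each row of a 2-D batch of token ids.

    Accepts a torch tensor or numpy array (converted via .tolist()) or nested lists.
    """
    rows = input_ids.tolist() if hasattr(input_ids, "tolist") else input_ids
    return [sum(1 for tok in row if tok == image_token_id) for row in rows]


# Hypothetical usage with the repro above (requires the loaded model and inputs):
# counts = image_tokens_per_sample(inputs["input_ids"], model.config.image_token_index)
# print(counts, inputs["image_sizes"])  # compare each row's count with its image size
```

Pairing each row's count with its original image size makes it easier to build a standalone, single-image repro like the one in the original report.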