huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Llama-3.2-11B-Vision-Instruct vocab size vs. lm head mismatch #33819

Open harshil-shah opened 3 weeks ago

harshil-shah commented 3 weeks ago

Who can help?

@ArthurZucker

Reproduction

Hi,

It seems there is a mismatch between the vocab size of the MllamaProcessor's tokenizer and the number of rows in the lm_head weight matrix. Calling resize_token_embeddings doesn't fix this, so computing a loss (and therefore training) fails. Minimal example:

import requests
from PIL import Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor

MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"

processor = MllamaProcessor.from_pretrained(MODEL_NAME)
model = MllamaForConditionalGeneration.from_pretrained(MODEL_NAME)

print(f"{len(processor.tokenizer) = }")
print(f"Before resize: {model.language_model.lm_head.weight.shape = }")

model.resize_token_embeddings(len(processor.tokenizer))

print(f"After resize: {model.language_model.lm_head.weight.shape = }")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
image = Image.open(requests.get(url, stream=True).raw)

messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "If I had to write a haiku for this one, it would be: "}
    ]}
]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
    image,
    input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to(model.device)

output = model(**inputs, labels=inputs.input_ids)

This outputs:

len(processor.tokenizer) = 128257
Before resize: model.language_model.lm_head.weight.shape = torch.Size([128256, 4096])
After resize: model.language_model.lm_head.weight.shape = torch.Size([128256, 4096])

And then errors with:

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.venv/lib/python3.11/site-packages/transformers/models/mllama/modeling_mllama.py:2188, in MllamaForConditionalGeneration.forward(self, input_ids, pixel_values, aspect_ratio_mask, aspect_ratio_ids, attention_mask, cross_attention_mask, cross_attention_states, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
   2185     cross_attention_mask = cross_attention_mask[:, :, cache_position]
   2186     full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
-> 2188 outputs = self.language_model(
   2189     input_ids=input_ids,
   2190     attention_mask=attention_mask,
   2191     position_ids=position_ids,
   2192     cross_attention_states=cross_attention_states,
   2193     cross_attention_mask=cross_attention_mask,
   2194     full_text_row_masked_out_mask=full_text_row_masked_out_mask,
   2195     past_key_values=past_key_values,
   2196     use_cache=use_cache,
   2197     inputs_embeds=inputs_embeds,
   2198     labels=labels,
   2199     output_hidden_states=output_hidden_states,
   2200     output_attentions=output_attentions,
   2201     return_dict=return_dict,
   2202     cache_position=cache_position,
   2203     num_logits_to_keep=num_logits_to_keep,
   2204 )
   2206 return outputs

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.venv/lib/python3.11/site-packages/transformers/models/mllama/modeling_mllama.py:1961, in MllamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, cross_attention_states, cross_attention_mask, full_text_row_masked_out_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
   1959     # Enable model parallelism
   1960     shift_labels = shift_labels.to(shift_logits.device)
-> 1961     loss = loss_fct(shift_logits, shift_labels)
   1963 if not return_dict:
   1964     output = (logits,) + outputs[1:]

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/loss.py:1179, in CrossEntropyLoss.forward(self, input, target)
   1178 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1179     return F.cross_entropy(input, target, weight=self.weight,
   1180                            ignore_index=self.ignore_index, reduction=self.reduction,
   1181                            label_smoothing=self.label_smoothing)

File ~/.venv/lib/python3.11/site-packages/torch/nn/functional.py:3059, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3057 if size_average is not None or reduce is not None:
   3058     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3059 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

IndexError: Target 128256 is out of bounds.
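The IndexError comes from the bounds check in cross-entropy: every target must lie in [0, num_classes) unless it equals ignore_index (-100 by default). With an lm_head of 128256 rows, the valid class ids are 0..128255, so a label of 128256 cannot be scored. That id appears to be the image token, the one extra entry that makes len(processor.tokenizer) equal 128257. The pure-Python sketch below mimics that check. find_invalid_targets is a hypothetical helper, not a torch or transformers API, and the toy label ids are illustrative only.

```python
# Hypothetical helper mimicking the target bounds check done by cross-entropy:
# a target is valid if it equals ignore_index or lies in [0, num_classes).
def find_invalid_targets(labels, num_classes, ignore_index=-100):
    """Return (position, label) pairs that cross-entropy would reject."""
    return [
        (pos, label)
        for pos, label in enumerate(labels)
        if label != ignore_index and not 0 <= label < num_classes
    ]


LM_HEAD_ROWS = 128256    # lm_head.weight.shape[0] from the output above
IMAGE_TOKEN_ID = 128256  # assumption: id of the one extra tokenizer entry

# Toy label sequence; ids other than the image token are made up for illustration.
labels = [128000, IMAGE_TOKEN_ID, 791, 4382, 128009]
print(find_invalid_targets(labels, LM_HEAD_ROWS))
# [(1, 128256)]
```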

Expected behavior

The vocab size of the processor and model should match.

LysandreJik commented 3 weeks ago

cc @qubvel maybe? :)

phionex2 commented 3 weeks ago

Hi all,

I’ve reviewed the issue, and it seems that the mismatch between the vocab size of the MllamaProcessor and the lm_head in the model is causing the IndexError. When the resize_token_embeddings() method is called, it does not appear to resize the lm_head weight matrix, which is leading to the error during training.

One way to address this would be for resize_token_embeddings() to also resize the lm_head layer, so that it matches the new vocabulary size. If anyone has suggestions or additional details that could help, feel free to share them.

Looking forward to your feedback.
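To make the proposal concrete: growing an output projection means adding one row per new token, so the logits gain one column per token. The sketch below pads a weight matrix stored as a list of rows and initializes each new row to the element-wise mean of the existing rows. The helper name and the mean initialization are assumptions for illustration, not a description of how resize_token_embeddings() is implemented.

```python
def pad_output_rows(weight, new_num_rows):
    """Return a copy of weight (a list of equal-length rows) padded to new_num_rows.

    Each added row is the element-wise mean of the existing rows.
    This sketch only grows the matrix; it never truncates it.
    """
    old_num_rows = len(weight)
    if new_num_rows < old_num_rows:
        raise ValueError("this sketch only grows the matrix")
    hidden = len(weight[0])
    mean_row = [sum(row[j] for row in weight) / old_num_rows for j in range(hidden)]
    extra = [list(mean_row) for _ in range(new_num_rows - old_num_rows)]
    return [list(row) for row in weight] + extra


# Toy 3x2 "lm_head" grown by one row, mirroring 128256 -> 128257 in the issue.
toy = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(pad_output_rows(toy, 4))
# [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [3.0, 4.0]]
```

With such an extra row, a target of 128256 would be in range, but the image token would then receive gradients as an output class.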

qubvel commented 3 weeks ago

Hi @phionex2, thanks for opening the issue! Please see this discussion regarding the mismatch and how to enable fine-tuning the model.

TL;DR: the main idea is that the image token is not intended to be trained, so it should be masked out of the labels:

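labels = inputs["input_ids"].clone()  # start from the input ids, as in the reproduction above
image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
labels[labels == image_token_id] = -100  # -100 is CrossEntropyLoss's default ignore_index

phionex2 commented 2 weeks ago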

Hi @qubvel, thanks for the heads-up!

I went through the discussion, and it seems the main issue stems from the image token not being intended for training. Based on the suggestion in the thread, masking the image token in the labels seems like the right approach.

image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
labels[labels == image_token_id] = -100

This effectively prevents the image token from contributing to the loss calculation during training.
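In plain Python, the masking step amounts to the following sketch. It works on a list rather than a torch.Tensor (the snippet above uses boolean-mask assignment on a tensor), and it assumes the image token id is 128256. In real code the id should come from processor.tokenizer.convert_tokens_to_ids(processor.image_token). The helper name mask_image_tokens is hypothetical.

```python
IGNORE_INDEX = -100      # default ignore_index of torch.nn.CrossEntropyLoss
LM_HEAD_ROWS = 128256    # number of classes the lm_head can score
IMAGE_TOKEN_ID = 128256  # assumption: id of the image token


def mask_image_tokens(labels, image_token_id, ignore_index=IGNORE_INDEX):
    """Return a copy of labels with every image token replaced by ignore_index."""
    return [ignore_index if label == image_token_id else label for label in labels]


labels = [128000, IMAGE_TOKEN_ID, 791, 4382, 128009]
masked = mask_image_tokens(labels, IMAGE_TOKEN_ID)
print(masked)  # [128000, -100, 791, 4382, 128009]

# After masking, every remaining target is either ignored or a valid class id.
assert all(l == IGNORE_INDEX or 0 <= l < LM_HEAD_ROWS for l in masked)
```

Returning a new list keeps the original input ids untouched, which matters because the same ids are still fed to the model as input_ids.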

ArthurZucker commented 1 week ago

Indeed! I think we can close this? Unfortunately, mismatches between the lm head and the embedding are "expected" given the way the model was designed 😢