huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Index error while pretraining Flava #27855

Closed: ferjorosa closed this issue 9 months ago

ferjorosa commented 10 months ago

System Info

transformers==4.35.2

Who can help?

@ArthurZucker @younesbelkada @amyeroberts

Reproduction

An IndexError is thrown during pretraining when the itm_labels tensor contains both 0s and 1s. As a reminder, the ITM task needs image-text pairs that do not match; the unmatched pairs are identified with a 1 in the itm_labels list.

import torch

# `model` is a FlavaForPreTraining instance; `text_inputs` and `image_inputs`
# were prepared earlier (see the Colab notebook mentioned below).
itm_labels = torch.tensor([0, 1, 0, 0, 0])

itm_outputs = model(
    # Text
    input_ids=text_inputs["input_ids"],  # Text input
    token_type_ids=text_inputs["token_type_ids"],
    attention_mask=text_inputs["attention_mask"],  # Text attention mask
    input_ids_masked=text_inputs["input_ids_masked"],  # MLM masked inputs
    mlm_labels=text_inputs["mlm_labels"],  # MLM labels: -100 except at the positions masked in input_ids_masked

    # Image
    pixel_values=image_inputs["pixel_values"],  # Image input
    bool_masked_pos=image_inputs["bool_masked_pos"],  # MIM mask (part of DALLE output): masked patches are 1, unmasked patches are 0
    codebook_pixel_values=image_inputs["codebook_pixel_values"],  # Inputs used to compute the MIM labels

    # Pure multimodal
    itm_labels=itm_labels,
)

itm_outputs.loss_info

Error:

---------------------------------------------------------------------------

IndexError                                Traceback (most recent call last)

<ipython-input-11-d659e68fc452> in <cell line: 3>()
      1 itm_labels = torch.tensor([0,1,0,0,0])
      2 
----> 3 itm_outputs = model(
      4     # Text
      5     input_ids=text_inputs["input_ids"],  # Text input

2 frames

/usr/local/lib/python3.10/dist-packages/transformers/models/flava/modeling_flava.py in forward(self, input_ids, input_ids_masked, pixel_values, codebook_pixel_values, attention_mask, token_type_ids, bool_masked_pos, position_ids, image_attention_mask, skip_unmasked_multimodal_encoder, mlm_labels, mim_labels, itm_labels, output_attentions, output_hidden_states, return_dict, return_loss)
   1966 
   1967             if pos_mask is not None:
-> 1968                 sequence_for_image = sequence_for_image[pos_mask]
   1969             if mim_labels is not None:
   1970                 mim_labels = self._resize_to_2d(mim_labels)

IndexError: The shape of the mask [5] at index 0 does not match the shape of the indexed tensor [1, 196, 768] at index 0
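The shape mismatch in the traceback can be reproduced without FLAVA at all. The sketch below uses a random tensor with the shapes from the error message (batch 5, 196 patches, hidden size 768) and assumes, purely for illustration, that the positive-pair mask is built as `itm_labels != 0`; it is not the library's code.

```python
import torch

# Toy tensor with the shapes from the traceback: (batch, patches, hidden).
embeddings = torch.randn(5, 196, 768)
itm_labels = torch.tensor([0, 1, 0, 0, 0])

# Assumption for illustration: rows with a nonzero label are kept.
pos_mask = itm_labels.ne(0)

# First application: the batch dimension shrinks from 5 to 1.
masked_once = embeddings[pos_mask]
print(masked_once.shape)  # torch.Size([1, 196, 768])

# Second application: the mask still has 5 entries, but the tensor now has
# only 1 row along dim 0, so boolean indexing raises IndexError.
try:
    masked_once[pos_mask]
    raised = False
except IndexError as err:
    raised = True
    print(err)
```

The printed error has the same form as the one raised inside `modeling_flava.py`.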

To make the error easy to reproduce, I have also prepared a Google Colab notebook, which can be found here.

As a side note, this error can go unnoticed when all items in itm_labels are 0s (every pair matches) or all are 1s (no pair matches). Note also that the code automatically converts an itm_labels tensor containing only 1s into all 0s; this silent conversion may lead to unexpected behaviour for the user.
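One plausible reason uniform labels hide the bug: if the mask keeps every row, applying it a second time cannot change the shape. The sketch below assumes the mask is `labels != 0` and adds a hypothetical fallback that keeps all rows when no label is nonzero; `mask_twice` is an illustrative helper, not part of transformers, and the real masking logic may differ.

```python
import torch

def mask_twice(labels: torch.Tensor) -> torch.Size:
    """Apply the same boolean mask twice to a toy (batch, 196, 8) tensor."""
    pos_mask = labels.ne(0)
    # Hypothetical fallback: if no label is nonzero, keep every row.
    if not pos_mask.any():
        pos_mask = torch.ones_like(pos_mask)
    x = torch.zeros(labels.shape[0], 196, 8)
    x = x[pos_mask]            # first application
    return x[pos_mask].shape   # second application

print(mask_twice(torch.tensor([1, 1, 1, 1, 1])))  # torch.Size([5, 196, 8])
print(mask_twice(torch.tensor([0, 0, 0, 0, 0])))  # torch.Size([5, 196, 8])
```

With mixed labels such as `[0, 1, 0, 0, 0]` the second application raises `IndexError`, matching the report.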

Expected behavior

The error occurs because FLAVA's forward pass applies pos_mask more than once: first on line 1953 of modeling_flava.py, and then again on lines 1968 (MMM image) and 1991 (MMM text). After the first application the batch dimension has already shrunk (here from 5 to 1), so the 5-element mask no longer matches the tensor it indexes. I think removing the second and third applications of the mask would fix it.
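The proposed fix, applying the mask once and then only slicing, can be sketched as follows. `split_masked_sequence` is a hypothetical helper, and the "image tokens first, text tokens after" layout is an assumption made for illustration; the actual slicing in modeling_flava.py may use different offsets (for example around special tokens).

```python
from typing import Optional, Tuple

import torch

def split_masked_sequence(
    multimodal_embeddings: torch.Tensor,
    pos_mask: Optional[torch.Tensor],
    num_image_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Keep positive pairs once, then slice out the image and text parts.

    Hypothetical layout: the first `num_image_tokens` positions are image
    tokens and the remaining positions are text tokens.
    """
    if pos_mask is not None:
        # Single application of the mask over the full batch.
        multimodal_embeddings = multimodal_embeddings[pos_mask]
    # These slices inherit the already-filtered batch dimension,
    # so they must not be masked again.
    sequence_for_image = multimodal_embeddings[:, :num_image_tokens, :]
    sequence_for_text = multimodal_embeddings[:, num_image_tokens:, :]
    return sequence_for_image, sequence_for_text

embeddings = torch.randn(5, 196 + 12, 8)
pos_mask = torch.tensor([0, 1, 0, 0, 0]).ne(0)
image_seq, text_seq = split_masked_sequence(embeddings, pos_mask, num_image_tokens=196)
print(image_seq.shape, text_seq.shape)  # torch.Size([1, 196, 8]) torch.Size([1, 12, 8])
```

Filtering the full sequence before slicing keeps both branches consistent, since the image and text slices can never disagree with the mask about the batch size.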

ArthurZucker commented 10 months ago

Hey! Thanks for reporting 🤗 Would you like to open a PR with a fix?

ferjorosa commented 10 months ago

Hi, yes. I have created a PR. Could you take a look at it?

Thanks

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.