huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Apache License 2.0

FLAVA not doing a forward pass #22103

Closed amariucaitheodor closed 1 year ago

amariucaitheodor commented 1 year ago

System Info

N.B. I do have PyTorch installed; I'm not sure why the tool can't detect it:

python -c "import torch; print(torch.__version__)" 
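When an environment report says PyTorch is missing but `import torch` works, one plausible (unconfirmed) cause is that the report runs under a different interpreter or virtual environment. The stdlib-only sketch below prints the running interpreter and, for each package, whether it is importable there and which version is installed. The helper name `report_packages` is hypothetical.

```python
import importlib.metadata
import importlib.util
import sys


def report_packages(names):
    """Return {name: (importable, version_or_None)} for the current interpreter."""
    report = {}
    for name in names:
        importable = importlib.util.find_spec(name) is not None
        try:
            version = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            version = None  # not installed as a distribution (or a stdlib module)
        report[name] = (importable, version)
    return report


if __name__ == "__main__":
    print("interpreter:", sys.executable)
    for name, (ok, ver) in report_packages(["torch", "transformers", "datasets"]).items():
        print(f"{name}: importable={ok}, version={ver}")
```

Running it with the same `python` that launches the environment tool, and comparing the printed `sys.executable`, should show whether both see the same packages.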

Who can help?





Steps to reproduce the behavior (a Colab notebook that does this is also available):

  1. Get a datapoint for a forward pass (fetch_images is in the notebook above):
    import datasets
    pmd = datasets.load_dataset("facebook/pmd", "wit", use_auth_token=True, streaming=True)
    pmd_train_head = pmd['train'].take(2)
    pmd_train_head_with_images = pmd_train_head.map(fetch_images, batched=True, batch_size=100, fn_kwargs={"num_threads": 20})
    datapoint = next(iter(pmd_train_head_with_images))
  2. Process the input:
    from transformers import FlavaProcessor, FlavaForPreTraining
    processor = FlavaProcessor.from_pretrained("facebook/flava-full")
    inputs = processor(
  3. Mask the text input for MLM:
    from transformers import DataCollatorForLanguageModeling
    data_collator = DataCollatorForLanguageModeling(processor.tokenizer, mlm=True, mlm_probability=0.4, return_tensors="pt")
    inputs['input_ids'], inputs['input_ids_masked'] = data_collator.torch_mask_tokens(inputs=inputs['input_ids'],
    del inputs['special_tokens_mask']
  4. Do a forward pass:
    model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
    outputs = model(**inputs)
    loss = outputs.loss
    print(f"loss: {loss}")
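Step 1 calls `fetch_images`, which lives in the linked notebook and is not reproduced in this issue. Assuming it follows the batched, thread-pooled pattern its arguments suggest (`batched=True`, `fn_kwargs={"num_threads": ...}`), it might look roughly like the stdlib-only sketch below. The column names `image_url` and `image` and the injectable `load` callable are assumptions for illustration; the notebook's function may differ.

```python
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def fetch_single_image(url, timeout=None, retries=0):
    """Download one URL and return its raw bytes, or None if every attempt fails."""
    for _ in range(retries + 1):
        try:
            with urllib.request.urlopen(url, timeout=timeout) as resp:
                return resp.read()
        except Exception:
            continue  # retry, then give up with None
    return None


def fetch_images(batch, num_threads, timeout=None, retries=0, load=fetch_single_image):
    """Batched map function: fill batch["image"] from batch["image_url"] in parallel.

    The column names are assumptions; `load` is injectable (hypothetical) so the
    function can be exercised without network access.
    """
    load_one = partial(load, timeout=timeout, retries=retries)
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # executor.map preserves input order, so images line up with their URLs
        batch["image"] = list(executor.map(load_one, batch["image_url"]))
    return batch
```

It would be passed exactly as in step 1: `pmd_train_head.map(fetch_images, batched=True, batch_size=100, fn_kwargs={"num_threads": 20})`. The notebook version presumably also decodes each download into a PIL image (for example with `PIL.Image.open(io.BytesIO(data))`), since the processor expects images rather than raw bytes.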

Expected behavior

I would expect the forward pass to run without raising errors.

Actual behavior


IndexError                                Traceback (most recent call last)

<ipython-input-14-b821d73f49e9> in <module>
      1 model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
----> 3 outputs = model(**inputs)
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/transformers/models/flava/ in forward(self, input_ids, input_ids_masked, pixel_values, codebook_pixel_values, attention_mask, token_type_ids, bool_masked_pos, position_ids, image_attention_mask, skip_unmasked_multimodal_encoder, mlm_labels, mim_labels, itm_labels, output_attentions, output_hidden_states, return_dict, return_loss)
   1857         )
-> 1859         flava_masked_output = self.flava(
   1860             input_ids=input_ids_masked,
   1861             pixel_values=pixel_values,

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/transformers/models/flava/ in forward(self, input_ids, pixel_values, attention_mask, token_type_ids, bool_masked_pos, position_ids, image_attention_mask, skip_multimodal_encoder, output_attentions, output_hidden_states, return_dict)
   1403         text_output = None
   1404         if input_ids is not None:
-> 1405             text_output = self.text_model(
   1406                 input_ids=input_ids,
   1407                 attention_mask=attention_mask,

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/transformers/models/flava/ in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, output_attentions, output_hidden_states, return_dict)
   1061         )
-> 1063         embedding_output = self.embeddings(
   1064             input_ids=input_ids,
   1065             token_type_ids=token_type_ids,

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/transformers/models/flava/ in forward(self, input_ids, token_type_ids, position_ids)
    417                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
--> 419         inputs_embeds = self.word_embeddings(input_ids)
    420         token_type_embeddings = self.token_type_embeddings(token_type_ids)

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/ in forward(self, input)
    159     def forward(self, input: Tensor) -> Tensor:
--> 160         return F.embedding(
    161             input, self.weight, self.padding_idx, self.max_norm,
    162             self.norm_type, self.scale_grad_by_freq, self.sparse)

/usr/local/lib/python3.9/dist-packages/torch/nn/ in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2208         # remove once script supports set_grad_enabled
   2209         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

IndexError: index out of range in self
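
The last frame is the word-embedding lookup. As a minimal, self-contained check (the vocabulary size and ids are arbitrary), `torch.nn.Embedding` on CPU raises this same `IndexError` for any id outside `[0, num_embeddings)`, including `-100`, the value used for ignored positions in MLM labels. On CUDA the same mistake typically surfaces as a device-side assertion instead.

```python
import torch


def lookup_fails(ids, num_embeddings=10, dim=4):
    """Return True if an embedding lookup on `ids` raises IndexError."""
    embedding = torch.nn.Embedding(num_embeddings, dim)
    try:
        embedding(torch.tensor(ids))
    except IndexError:  # "index out of range in self"
        return True
    return False


print(lookup_fails([1, 2, 3]))     # all ids valid
print(lookup_fails([1, -100, 3]))  # -100 from MLM labels is not a valid id
```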

ydshieh commented 1 year ago

Hi @amariucaitheodor. Thank you for reporting the issue!

Could you also copy-paste the error (traceback) you got into the issue description above? Thanks.

apsdehal commented 1 year ago

I tried the Colab notebook and found the problem. The code that computes input_ids and input_ids_masked is incorrect: torch_mask_tokens returns two values, the masked copy of input_ids and the corresponding labels. Since the loss is only calculated on the masked tokens, all other positions in the labels are set to -100. The snippet unpacks those labels into input_ids_masked, so -100 is later looked up in the word embeddings, which causes the "index out of range" error in the embeddings' forward.
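To make the two return values concrete, here is a framework-free sketch of the masking scheme that the `torch_mask_tokens` docstring describes (select tokens with probability `mlm_probability`; of those, replace about 80% with the mask token, about 10% with a random token, and keep the rest). It works on a single Python list, and `mask_tokens_sketch` and its parameters are hypothetical, not the library API. Like the real method, it modifies its input in place.

```python
import random

IGNORE_INDEX = -100  # label value the MLM loss skips; never a valid token id


def mask_tokens_sketch(ids, special, mask_id, vocab_size, mlm_probability=0.15, rng=random):
    """Hypothetical stand-in for torch_mask_tokens on one list of ids.

    Mutates `ids` in place and returns (ids, labels).
    """
    labels = list(ids)  # labels start as a copy of the original ids
    for i, is_special in enumerate(special):
        if is_special or rng.random() >= mlm_probability:
            labels[i] = IGNORE_INDEX  # not selected: excluded from the loss
            continue
        roll = rng.random()
        if roll < 0.8:
            ids[i] = mask_id  # ~80% of selected tokens: mask token
        elif roll < 0.9:
            ids[i] = rng.randrange(vocab_size)  # ~10%: random token
        # remaining ~10%: keep the original token
    return ids, labels


# Hypothetical ids for a 4-token sentence; select every non-special token.
ids = [101, 7592, 2088, 102]
masked, labels = mask_tokens_sketch(
    ids, [True, False, False, True], mask_id=103, vocab_size=30522,
    mlm_probability=1.0, rng=random.Random(0),
)
print(labels)  # [-100, 7592, 2088, -100]
```

Unpacking the result as `input_ids, input_ids_masked = ...`, as in step 3 of the report, puts `labels` with its -100 entries into `input_ids_masked`, and the model then sends that tensor through its text embeddings, which matches the traceback.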

amariucaitheodor commented 1 year ago

Thank you for the reply! I had noticed the same problem. What, then, is the correct way to compute input_ids_masked? The code doesn't work with DataCollatorForLanguageModeling for the reasons mentioned above, and I couldn't find any other example of how to do this.

ydshieh commented 1 year ago

Thank you @amariucaitheodor for providing the error log, and thanks @apsdehal for sharing your finding. I will take a look at this issue. But @apsdehal, don't hesitate to share if you have any ideas about the correct solution ❤️

ydshieh commented 1 year ago

Hello! After looking into the issue with the notebook, here is my finding:

The solution is just to prepare the correct inputs for the model:

inputs['input_ids_masked'], _ = data_collator.torch_mask_tokens(

With this change, I get loss: 7.162976264953613.

Let me know if you have further question 🤗

apsdehal commented 1 year ago

@ydshieh I don't think this is fully correct either: torch_mask_tokens masks the input_ids in place, so you have to clone input_ids before passing them to it.
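Combining the two comments: the masked copy goes into `input_ids_masked`, while `input_ids` must stay unmasked, so the tensor is cloned before being masked in place. A minimal sketch of that preparation step follows. The helper name `prepare_mlm_inputs` is hypothetical; it accepts any collator whose `torch_mask_tokens(inputs, special_tokens_mask=...)` returns `(masked_inputs, labels)`, which is how `DataCollatorForLanguageModeling` behaves per this thread. Forwarding the labels as `mlm_labels` (an argument visible in the `FlavaForPreTraining.forward` signature in the traceback) is an assumption the thread does not confirm, so it can be switched off.

```python
import torch


def prepare_mlm_inputs(inputs, data_collator, add_mlm_labels=True):
    """Add `input_ids_masked` (and optionally `mlm_labels`) without touching `input_ids`."""
    special_tokens_mask = inputs.pop("special_tokens_mask", None)
    masked, labels = data_collator.torch_mask_tokens(
        inputs=inputs["input_ids"].clone(),  # clone: the collator masks its input in place
        special_tokens_mask=special_tokens_mask,
    )
    inputs["input_ids_masked"] = masked
    if add_mlm_labels:
        inputs["mlm_labels"] = labels  # assumption: -100 marks positions to ignore
    return inputs


# With the objects from the issue (not run here, since it downloads the model):
# inputs = prepare_mlm_inputs(inputs, data_collator)
# outputs = model(**inputs)
```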

ydshieh commented 1 year ago

@apsdehal Thanks a lot, nice catch! You are 100% correct. @amariucaitheodor Please see this comment too!

ydshieh commented 1 year ago

As it turns out, this is not an issue in the transformers modeling code but a problem with how the model inputs were prepared, so I'm going to close the issue.

@amariucaitheodor If you still have issues, you can post on the Hugging Face Forums.

However, if you find other issues that you believe are in the modeling code, feel free to keep leaving comments here.