UKPLab / sentence-transformers

Multilingual Sentence & Image Embeddings with BERT
https://www.SBERT.net
Apache License 2.0

TypeError: T5EncoderModel.forward() got an unexpected keyword argument 'token_type_ids' #2588

Open atasoglu opened 3 months ago

atasoglu commented 3 months ago

Hi,

I am trying to use boun-tabi-LMG/TURNA, a Turkish T5 model, with sentence-transformers as it has been specifically pre-trained for Turkish.

Running the code snippet below raises the following TypeError:

```python
from sentence_transformers import models, SentenceTransformer
t5_model = models.Transformer("boun-tabi-LMG/TURNA")
pooling_model = models.Pooling(t5_model.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[t5_model, pooling_model])
model.encode(["Merhaba dünya!"])
```

Out:

```
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-e5957cf51002> in <cell line: 5>()
      3 pooling_model = models.Pooling(t5_model.get_word_embedding_dimension(), pooling_mode="mean")
      4 model = SentenceTransformer(modules=[t5_model, pooling_model])
----> 5 model.encode(["Merhaba dünya!"])

/usr/local/lib/python3.10/dist-packages/sentence_transformers/SentenceTransformer.py in encode(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings)
    355 
    356             with torch.no_grad():
--> 357                 out_features = self.forward(features)
    358 
    359                 if output_value == "token_embeddings":

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py in forward(self, input)
    215     def forward(self, input):
    216         for module in self:
--> 217             input = module(input)
    218         return input
    219 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/sentence_transformers/models/Transformer.py in forward(self, features)
     96             trans_features["token_type_ids"] = features["token_type_ids"]
     97 
---> 98         output_states = self.auto_model(**trans_features, return_dict=False)
     99         output_tokens = output_states[0]
    100 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

TypeError: T5EncoderModel.forward() got an unexpected keyword argument 'token_type_ids'
```

Thank you in advance for your assistance and guidance!

tomaarsen commented 3 months ago

Hello!

I believe this is a configuration issue on the side of boun-tabi-LMG/TURNA. Its tokenizer returns token_type_ids even though the model does not seem to use them. Sentence Transformers assumes that if the tokenizer returns token_type_ids, the model requires them, so it passes them on to the model.
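To make that failure mode concrete, here is a simplified, hypothetical sketch of the behavior described above; it is not the actual sentence_transformers source. `build_model_inputs` and `encoder_forward` are made-up names, and the token IDs are toy values. The point is that once `token_type_ids` appears in the tokenizer features, it gets unpacked into a `forward()` that has no such parameter.

```python
# Hypothetical sketch of how a Transformer module might assemble model inputs
# from tokenizer features; not the library's actual source.


def build_model_inputs(features: dict) -> dict:
    """Collect the keyword arguments that will be passed to the model's forward()."""
    trans_features = {
        "input_ids": features["input_ids"],
        "attention_mask": features["attention_mask"],
    }
    # token_type_ids is forwarded whenever the tokenizer produced it,
    # on the assumption that the model expects it.
    if "token_type_ids" in features:
        trans_features["token_type_ids"] = features["token_type_ids"]
    return trans_features


def encoder_forward(input_ids=None, attention_mask=None):
    """Stand-in for a T5-style encoder forward() with no token_type_ids parameter."""
    return [i * m for i, m in zip(input_ids, attention_mask)]


features = {
    "input_ids": [312, 45, 7, 1],        # toy values
    "attention_mask": [1, 1, 1, 1],
    "token_type_ids": [0, 0, 0, 0],      # emitted by a misconfigured tokenizer
}

try:
    encoder_forward(**build_model_inputs(features))
except TypeError as err:
    # Same shape of error as in the traceback: unexpected keyword 'token_type_ids'
    print(f"TypeError: {err}")
```

If the tokenizer stops emitting `token_type_ids`, the same call succeeds, which is why the fix belongs on the tokenizer side.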

The same error occurs without Sentence Transformers, e.g. with the following script:

```python
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("boun-tabi-LMG/TURNA")
tokenizer = AutoTokenizer.from_pretrained("boun-tabi-LMG/TURNA")

inputs = tokenizer("Merhaba dünya!", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
```

This also raises:

```
TypeError: T5Model.forward() got an unexpected keyword argument 'token_type_ids'
```

I suspect this is because the configured tokenizer class here is PreTrainedTokenizerFast rather than, e.g., T5TokenizerFast. The former assumes by default that token_type_ids is one of the model inputs: https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/tokenization_utils_base.py#L1561
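The role of `model_input_names` can be illustrated with a toy, self-contained sketch. The class names below are hypothetical, and the "tokenization" just maps characters to code points. The two class-level lists are meant to mirror the generic transformers default (`["input_ids", "token_type_ids", "attention_mask"]`) and the T5 override (`["input_ids", "attention_mask"]`). Real tokenizers roughly decide whether to return `token_type_ids` by checking membership in this list, which is what the filter below imitates.

```python
# Hypothetical sketch: a tokenizer returns only the inputs listed in
# model_input_names. The tokenization itself is a toy.


class GenericTokenizerSketch:
    # Generic default: token_type_ids is treated as a model input
    model_input_names = ["input_ids", "token_type_ids", "attention_mask"]

    def __call__(self, text: str) -> dict:
        ids = [ord(ch) for ch in text]  # toy: one id per character
        candidates = {
            "input_ids": ids,
            "token_type_ids": [0] * len(ids),
            "attention_mask": [1] * len(ids),
        }
        # Only keys declared in model_input_names are returned
        return {k: v for k, v in candidates.items() if k in self.model_input_names}


class T5StyleTokenizerSketch(GenericTokenizerSketch):
    # T5-style override: token_type_ids is not a model input
    model_input_names = ["input_ids", "attention_mask"]


print(sorted(GenericTokenizerSketch()("Merhaba")))  # ['attention_mask', 'input_ids', 'token_type_ids']
print(sorted(T5StyleTokenizerSketch()("Merhaba")))  # ['attention_mask', 'input_ids']
```

Overriding the list on a subclass (or on an instance) is enough to change what the tokenizer emits, without touching the model.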

So, the patch is as follows:

```python
from sentence_transformers import models, SentenceTransformer

t5_model = models.Transformer("boun-tabi-LMG/TURNA")
pooling_model = models.Pooling(t5_model.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[t5_model, pooling_model])
# Remove token_type_ids from the tokenizer's model input names, as the model does not use it
model.tokenizer.model_input_names.remove("token_type_ids")

embeddings = model.encode(["Merhaba dünya!"])
print(embeddings.shape)
```

```
(1, 1024)
```
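One caveat with the `.remove(...)` line is that `list.remove` raises `ValueError` if `"token_type_ids"` is not in the list, e.g. if the script is rerun on an already-patched tokenizer. A slightly more defensive variant is sketched below. `drop_model_input` is a hypothetical helper, and a `SimpleNamespace` stands in for `model.tokenizer` so the snippet runs on its own.

```python
from types import SimpleNamespace


def drop_model_input(tokenizer, name: str = "token_type_ids") -> None:
    """Hypothetical helper: stop a tokenizer from listing `name` as a model input.

    Rebinds model_input_names to a filtered copy instead of calling
    list.remove(), so it does not raise ValueError when `name` is absent
    and does not mutate a list object that may be shared elsewhere.
    """
    tokenizer.model_input_names = [n for n in tokenizer.model_input_names if n != name]


# Stand-in for model.tokenizer; in the real script you would pass model.tokenizer
tokenizer = SimpleNamespace(model_input_names=["input_ids", "token_type_ids", "attention_mask"])
drop_model_input(tokenizer)
print(tokenizer.model_input_names)  # ['input_ids', 'attention_mask']

# Calling it again is a no-op rather than an error
drop_model_input(tokenizer)
```

In the original script, `drop_model_input(model.tokenizer)` should behave like the `.remove(...)` line, assuming the tokenizer reads `model_input_names` each time it is called.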

You can now use or finetune the model as normal. Hope this helps! You could also open a discussion at https://huggingface.co/boun-tabi-LMG/TURNA to point out that their tokenizer's model_input_names may be misconfigured, or that they might want to change the tokenizer class (e.g. T5TokenizerFast has the correct model_input_names here).

atasoglu commented 3 months ago

It worked! Thank you very much for your detailed answer and thoughtful advice on the tokenizer!