huggingface / tokenizers

💥 Fast State-of-the-Art Tokenizers optimized for Research and Production
https://huggingface.co/docs/tokenizers
Apache License 2.0

Inconsistent behaviour of `PreTrainedTokenizerFast`s on diacritic-marked texts #1663

Open sven-nm opened 1 week ago

sven-nm commented 1 week ago

Who can help?

@ArthurZucker @itazap

Reproduction

`BatchEncoding.encodings[0].word_ids` contains alignment errors when the input contains diacritics (e.g. the accents and breathings of polytonic Greek). Here is a minimal reproducible example:

from typing import List
import transformers
import unicodedata

# Instantiate the PreTrainedTokenizerFast
model_name = 'FacebookAI/xlm-roberta-base'
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, add_prefix_space=False)

# Example sentences
sentences: List[str] = [
    """And Odys- seus in 1. 1367, without caring to resent the sneer, simply reaffirms his right to take a line of his own, and pleads the reasonableness of his trying to win those in authority over to his side. On which Agamemnon (1. 1358) throws the entire responsibility on Odysseus, and Odysseus says (1. 1369), ‘ That makes no differ ence. Your consent, in whatever terms it is granted, will be equally kind.” If this is rejected, 1. 1366 must refer not to Odysseus’ words, but merely to his attitude of dissent. 1. 1367 is thus less pointed. For the meaning given to ἐνθάδ᾽ ἵξομαι, l. 136%, cp. Eur. Androm. 342, ἀλλ’ εἶσιν of xpf,—and for ὡς dv, 1. 1369, cp. O. C. 1361, and note. 1371. σοὶ μέν, ker A] For this un- gracious expression, cp. O. T. 671, 2, τό γὰρ σόν, οὐ τὸ τοῦδ᾽, ἐποικτείρω στόμα | ἐλεινόν, οὗτος δ᾽, ἔνθ᾽ ἂν 7, στυγήσεται. 1372. κἀκεῖ κἀνθάδ᾽ | E.on 1,. 841.}.γ8. 1373. σοὶ δὲ. ἃ Ἐχρή.] ‘You may do what you must:’ an ill-humoured way of saying, ‘Do as you please.” χρή, although rejected by Dindgrf and others in favour of χρῇς, i.e χρήζεις, is not inexpressive,and is possibly right. Cp. El. 606.—Exit Agamemnon. 1375. τοιοῦτον ὄντα] ‘While you act in this way. Cp. Phil. 1049, οὗ γὰρ τοιούτων δεῖ, τοιοῦτός εἰμ᾽ ἐγώ,""",
    # """Hello, this is a long sentence in ASCII with a lot of words, and it should be tokenized correctly 1234 how do you feel ? Hello, my name Foo, I'm a friend of Bar.""",

]

# Convert to NFC to make sure there is no floating combining character
sentences = [unicodedata.normalize('NFC', s) for s in sentences]

# Let's start with a working scenario. Here, I pre-tokenize the inputs myself with a blunt
# whitespace split. We then run the tokenizer and compare the maximum index in
# `BatchEncoding.encodings[0].word_ids` with the number of words: it should equal the number of words minus 1.
sentences_pretokenized: List[List[str]] = [s.split() for s in sentences]

batch_encoding = tokenizer(sentences_pretokenized, # ⚠️ Using the pretokenized sentences (List[List[str]])
                           padding=True,
                           truncation=True,
                           max_length=tokenizer.model_max_length,
                           pad_to_multiple_of=tokenizer.model_max_length,
                           add_special_tokens=True,
                           is_split_into_words=True) # ⚠️ Setting this to True

max_word_id = max([word_id for word_id in batch_encoding.encodings[0].word_ids if word_id is not None])
number_of_words = len(sentences_pretokenized[0])

print(f"Max word_id: {max_word_id}") # 225 ✅
print(f"Real number of words: {number_of_words}") # 226

# Good, this is what we were hoping to see: the alignment is correct. However, let's look at what
# happens if I pass the raw sentences directly, which the tokenizer should also accept:
batch_encoding = tokenizer(sentences, # ⚠️ Using the raw sentences (List[str])
                           padding=True,
                           truncation=True,
                           max_length=tokenizer.model_max_length,
                           pad_to_multiple_of=tokenizer.model_max_length,
                           add_special_tokens=True,
                           is_split_into_words=False) # ⚠️ Setting this to False (default, but explicit for clarity)

max_word_id = max([word_id for word_id in batch_encoding.encodings[0].word_ids if word_id is not None])
number_of_words = len(sentences_pretokenized[0])

print(f"Max word_id: {max_word_id}") # 231 ❌  WRONG! 
print(f"Real number of words: {number_of_words}") # 226

# Now let us see where the alignment starts to mismatch: 
for word_id, token in zip(batch_encoding.encodings[0].word_ids, batch_encoding.encodings[0].tokens):
    if word_id is None:
        print(f"Token: {token}")
        continue
    try:
        print(f"Word: {sentences_pretokenized[0][word_id]},\t\tToken: {token}")
    except IndexError:  # word_id points past the last whitespace-delimited word
        print("-------ERROR-------")
        print(f"Token: {token}")

# .....
# Word: ἐνθάδ᾽,     Token: ▁ἐ
# Word: ἐνθάδ᾽,     Token: ν
# Word: ἐνθάδ᾽,     Token: θ
# Word: ἐνθάδ᾽,     Token: άδ
# Word: ἵξομαι,,        Token: ▁
# Word: ἵξομαι,,        Token: ̓  # <--- This combining diacritic seems to be causing the issue
# Word: l.,     Token: ▁
# Word: l.,     Token: ἵ
# Word: l.,     Token: ξ
# Word: l.,     Token: ομαι
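A possible explanation, offered as an unverified hypothesis rather than a confirmed diagnosis: under compatibility normalization (NFKC), U+1FBD GREEK KORONIS (assumed here to be the code point of the ᾽ elision mark in the sample) decomposes into a SPACE followed by U+0313 COMBINING COMMA ABOVE, which would match the `▁` + `̓` token pair in the output above. If the XLM-R fast tokenizer's normalizer applies NFKC-style rules (SentencePiece's default `nmt_nfkc` rule does; that this checkpoint uses it is an assumption), each ᾽ would introduce a new whitespace split, and therefore an extra word id, only when the raw string is tokenized. With `is_split_into_words=True` the words are fixed before normalization, so the inserted space cannot start a new word. The sample sentence contains six ᾽ marks, which would line up with the gap between 225 and 231 under this hypothesis. The stdlib-only sketch below checks the decomposition and counts the extra splits; `extra_splits_after_nfkc` is a hypothetical helper, and `str.split()` is only an approximation of the tokenizer's real pre-tokenizer.

```python
import unicodedata

# GREEK KORONIS; assumed to be the elision mark (᾽) used in the sample text.
KORONIS = "\u1fbd"

def extra_splits_after_nfkc(text: str) -> int:
    """Hypothetical helper: how many more whitespace-separated chunks `text`
    has after NFKC than after NFC. A rough proxy (assumption) for the extra
    "words" an NFKC-style normalizer plus whitespace pre-tokenizer would see."""
    before = len(unicodedata.normalize("NFC", text).split())
    after = len(unicodedata.normalize("NFKC", text).split())
    return after - before

# NFKC turns the koronis into SPACE + COMBINING COMMA ABOVE.
print([hex(ord(c)) for c in unicodedata.normalize("NFKC", KORONIS)])
# ['0x20', '0x313']

# A short excerpt of the sample, with the koronis written as an escape.
sample = "ἐνθάδ\u1fbd ἵξομαι, l. 136%"
print(extra_splits_after_nfkc(sample))  # 1: "ἐνθάδ᾽" becomes "ἐνθάδ" + "̓"
```

Running `extra_splits_after_nfkc(sentences[0])` on the full sample could then be compared against the gap between the two `max_word_id` values.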

Expected behavior

NOTE: I am aware that similar problems have been raised before (https://github.com/huggingface/transformers/issues/9637), and that the problem also exists with other models, even with ASCII-only examples (e.g. setting model_name to bert-base-uncased and using only the second example). FacebookAI/xlm-roberta-base, however, produces seamless alignment with ASCII-only text.
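For the bert-base-uncased case, one likely contributor (not investigated in this thread) is that BERT's pre-tokenizer splits on punctuation as well as whitespace, so its "words" are not the chunks produced by `str.split()`. The regex below is a rough stdlib approximation of that behaviour, given for illustration only; `approx_bert_words` is a hypothetical helper, not the actual `BertPreTokenizer` rules (which, for instance, also treat `_` as punctuation).

```python
import re
from typing import List

def approx_bert_words(text: str) -> List[str]:
    """Hypothetical helper approximating BERT-style pre-tokenization:
    runs of word characters, or single characters that are neither
    word characters nor whitespace (i.e. punctuation pieces)."""
    return re.findall(r"\w+|[^\w\s]", text)

ascii_example = "Hello, my name Foo, I'm a friend of Bar."
print(len(ascii_example.split()))             # 9 whitespace-delimited words
print(len(approx_bert_words(ascii_example)))  # 14 pieces once punctuation is split off
```

If BERT's `word_ids` index these finer pieces, comparing them with `str.split()` indices would mismatch even on ASCII text, whereas XLM-R's pre-tokenizer splits on whitespace only, which fits the seamless ASCII alignment observed above.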

I think there should at least be a warning, as misalignments can have dramatic downstream consequences (notably for token classification tasks).
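Until something like that exists, one possible workaround, sketched here under assumptions and not tested against this exact checkpoint, is to rebuild word indices from character offsets instead of trusting `word_ids`. Fast tokenizers report each token's `(start, end)` span in the original string (`Encoding.offsets`, or `return_offsets_mapping=True`); the assumption is that tokens created by the normalizer, such as the inserted space, still map back into the original word. `word_ids_from_offsets` and `count_mismatches` are hypothetical helpers: a token whose span starts in whitespace is attributed to the following word, and empty spans such as the `(0, 0)` of special and padding tokens map to `None`.

```python
import bisect
import re
import warnings
from typing import List, Optional, Sequence, Tuple

def word_ids_from_offsets(text: str,
                          offsets: Sequence[Tuple[int, int]]) -> List[Optional[int]]:
    """Hypothetical helper: map each token's character span to the index of the
    whitespace-delimited word of `text` (the same chunks as `text.split()`)."""
    # End position of every whitespace-delimited word, in order.
    word_ends = [m.end() for m in re.finditer(r"\S+", text)]
    ids: List[Optional[int]] = []
    for start, end in offsets:
        if start == end:
            # Empty spans: special and padding tokens.
            ids.append(None)
            continue
        # First word that ends after the token starts; a token starting in
        # whitespace therefore attaches to the following word.
        idx = bisect.bisect_right(word_ends, start)
        ids.append(idx if idx < len(word_ends) else None)
    return ids

def count_mismatches(word_ids: Sequence[Optional[int]],
                     reference: Sequence[Optional[int]]) -> int:
    """Hypothetical helper: count positions where the two sequences differ,
    emitting a warning if there are any."""
    mismatches = sum(1 for a, b in zip(word_ids, reference) if a != b)
    if mismatches:
        warnings.warn(f"{mismatches} token(s) have a word id that differs "
                      "from the whitespace-based reference")
    return mismatches
```

With the reproduction above, this could be applied as `enc = batch_encoding.encodings[0]` followed by `count_mismatches(enc.word_ids, word_ids_from_offsets(sentences[0], enc.offsets))`.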

ArthurZucker commented 2 days ago

Hey! Pretty sure this is related to the tokenizers library directly. I don't have time to investigate as of right now, hope someone can help! 🤗