huggingface / tokenizers

💥 Fast State-of-the-Art Tokenizers optimized for Research and Production
https://huggingface.co/docs/tokenizers
Apache License 2.0

Possible bug in case of prepending chars in a pretokenizer #1423

Closed: ivankrylatskoe closed this issue 4 months ago

ivankrylatskoe commented 9 months ago

Please consider the following cases and give your opinion on whether this is a bug or not.

Base setup

from tokenizers import models, pre_tokenizers, trainers, Tokenizer
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(add_prefix_space=True, replacement="X")
trainer = trainers.BpeTrainer(vocab_size=6, special_tokens=["<s>"])
tokenizer.train_from_iterator(["<s>a"], trainer=trainer)
print(tokenizer.get_vocab())

Output: {'a': 4, '<s>': 0, '<': 1, '>': 2, 'X': 3, 's': 5}

Alternatively, you may get the same tokenizer in the following way:

tokenizer = Tokenizer.from_str(
"""
{
    "version":"1.0",
    "added_tokens":[
        {"id":0,"content":"<s>","single_word":false,"lstrip":false,"rstrip":false,"normalized":false,"special":true}
    ],
    "pre_tokenizer":{
        "type":"Metaspace","replacement":"X","add_prefix_space":true
    },
    "model":{
        "type":"BPE",
        "vocab":{
            "<s>":0,
            "<":1,
            ">":2,
            "X":3,
            "a":4,
            "s":5
        },
        "merges":[]
    }
}
""")

Case 1

First, let's check Metaspace.

print(tokenizer.pre_tokenizer.pre_tokenize_str('<s>a'))

Output: [('X<s>a', (0, 4))] This is ok. The pre-tokenizer added X at the beginning of the text.

print(tokenizer.encode('<s>a').tokens)

Output: ['<s>', 'X', 'a'] Why is the token order reversed???

Case 2

Now let's check our custom pre-tokenizer.

from tokenizers import NormalizedString, PreTokenizedString
from typing import List
class CustomPreTokenizer:

    def __init__(self, prefix):
        self.prefix = prefix

    def add_prefix(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        # Prepend the prefix to the split and keep it as a single piece.
        normalized_string.prepend(self.prefix)
        return [normalized_string]

    def pre_tokenize(self, pretokenized: PreTokenizedString):
        # Apply add_prefix to every current split (here: the whole input).
        pretokenized.split(self.add_prefix)

tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(CustomPreTokenizer("X"))
print(tokenizer.pre_tokenizer.pre_tokenize_str('<s>a'))

Output: [('X<s>a', (0, 4))] We get the same result as in Case 1. It's ok.

print(tokenizer.encode('<s>a').tokens)

Output: ['<s>', 'X', 'a'] We get the same result as in Case 1. Is it not ok?

Case 3

Adding more than one character.

tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(CustomPreTokenizer("XXXXXXX"))
print(tokenizer.pre_tokenizer.pre_tokenize_str('<s>a'))

Output: [('XXXXXXX<s>a', (0, 4))] Ok.

print(tokenizer.encode('<s>a').tokens)

Output: ['<s>', 'X', 'X', 'X', 'X', 'X', 'X', 'X', 'a'] Now the prepended characters jump all the way over '<s>'. Why?

Case 4

Adding a special token.

tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(CustomPreTokenizer("<s>"))
print(tokenizer.pre_tokenizer.pre_tokenize_str('<s>a'))

Output: [('<s><s>a', (0, 4))] Ok.

print(tokenizer.encode('<s>a').tokens)

Output: ['<s>', '<', 's', '>', 'a'] Why is the same token tokenized in two different ways???

print(tokenizer.pre_tokenizer.pre_tokenize_str('<s><s><s>aaaaa'))

Output: [('<s><s><s><s>aaaaa', (0, 14))] Ok.

print(tokenizer.encode('<s><s><s>aaaaa').tokens)

Output: ['<s>', '<s>', '<s>', '<', 's', '>', 'a', 'a', 'a', 'a', 'a'] Again, why do we get different results for the same token?

Case 5

Adding several special tokens.

tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(CustomPreTokenizer("<s><s><s>"))
print(tokenizer.pre_tokenizer.pre_tokenize_str('<s><s>aaaaa'))

Output: [('<s><s><s><s><s>aaaaa', (0, 11))] Ok.

print(tokenizer.encode('<s><s>aaaaa').tokens)

Output: ['<s>', '<s>', '<', 's', '>', '<', 's', '>', '<', 's', '>', 'a', 'a', 'a', 'a', 'a'] Is it a mess??

Case 6

Let's check the correctness of an empty pre-tokenizer.

tokenizer.pre_tokenizer = pre_tokenizers.Sequence([])
print(tokenizer.pre_tokenizer.pre_tokenize_str('<s><s><s><s><s>aaaaa'))

Output: [('<s><s><s><s><s>aaaaa', (0, 20))] Ok.

print(tokenizer.encode('<s><s><s><s><s>aaaaa').tokens)

Output: ['<s>', '<s>', '<s>', '<s>', '<s>', 'a', 'a', 'a', 'a', 'a'] Finally, it's ok.

github-actions[bot] commented 8 months ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

ivankrylatskoe commented 8 months ago

Ping

ArthurZucker commented 8 months ago

Sorry, I did not have a look. This is a bit strange, but I think the recent update to tokenizers will help you. You should set prepend_scheme: "first":

tokenizer = Tokenizer.from_str(
    """
    {
        "version": "1.0",
        "added_tokens": [
            {
                "id": 0,
                "content": "<s>",
                "single_word": false,
                "lstrip": false,
                "rstrip": false,
                "normalized": false,
                "special": true
            }
        ],
        "pre_tokenizer": {
            "type": "Metaspace",
            "replacement": "X",
            "add_prefix_space": true,
            "prepend_scheme": "first"
        },
        "model": {
            "type": "BPE",
            "vocab": {
                "<s>": 0,
                "<": 1,
                ">": 2,
                "X": 3,
                "a": 4,
                "s": 5
            },
            "merges": []
        }
    }
    """
)

result_tokens = tokenizer.encode("aaa<s><").tokens
print(result_tokens)
['X', 'a', 'a', 'a', '<s>', '<']
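For completeness, the same option can presumably be set directly from Python rather than through a serialized config. A minimal sketch, assuming a tokenizers release recent enough for Metaspace to expose the prepend_scheme argument:

from tokenizers import pre_tokenizers

# Assumed equivalent of the "prepend_scheme": "first" entry in the JSON above:
# only the very first split of the input gets the replacement prefix.
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="X", prepend_scheme="first")
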
ivankrylatskoe commented 7 months ago

@ArthurZucker, hi! Thanks for your answer!

Please check my example: print(tokenizer.encode('<s>a').tokens). With the latest tokenizers and your tokenizer setup I still get a strange result: ['<s>', 'X', 'a']

ivankrylatskoe commented 6 months ago

Ping

ivankrylatskoe commented 6 months ago

Still no solution

ArthurZucker commented 3 months ago

I am not getting this on 0.19:

[screenshot of the encode output on 0.19]

ivankrylatskoe commented 3 months ago

Hi! Yes, encode works. But tokenizer.pre_tokenizer.pre_tokenize_str still doesn't work. So, the problem is not solved.

ArthurZucker commented 3 months ago

This is expected: the pre_tokenizer does not have access to information about the special tokens, so it will always prepend, regardless of whether the first token is a special token or not.

In [7]: tokenizer.pre_tokenizer.pre_tokenize_str("Xaaa")
Out[7]: [('Xaaa', (0, 4))]

As long as the prefix is not prepended twice, it is working as expected, I believe.
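
In other words, a minimal sketch of the behaviour described above, rebuilding the base setup from the top of this issue (note that add_prefix_space may be deprecated in newer tokenizers releases): encode first splits out the added token '<s>' and only then runs the pre-tokenizer on the leftover segment, so the Metaspace prefix ends up in front of 'a' rather than in front of the whole string.

from tokenizers import models, pre_tokenizers, trainers, Tokenizer

# Rebuild the base-setup tokenizer from the top of the issue.
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(add_prefix_space=True, replacement="X")
trainer = trainers.BpeTrainer(vocab_size=6, special_tokens=["<s>"])
tokenizer.train_from_iterator(["<s>a"], trainer=trainer)

# encode() extracts the added token '<s>' before pre-tokenization, so the
# pre-tokenizer only ever sees the remaining segment 'a'.
print(tokenizer.pre_tokenizer.pre_tokenize_str('a'))  # expected: [('Xa', (0, 1))]
print(tokenizer.encode('<s>a').tokens)                # ['<s>', 'X', 'a'] as in Case 1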