huggingface / tokenizers

💥 Fast State-of-the-Art Tokenizers optimized for Research and Production
https://huggingface.co/docs/tokenizers
Apache License 2.0

added_tokens with bytemap characters in ByteLevel could not be decoded correctly #1392

Open DOGEwbx opened 7 months ago

DOGEwbx commented 7 months ago

I just found that added tokens containing characters that exist in the byte map of the ByteLevel pre-tokenizer cannot be decoded correctly. Here is a script to reproduce the problem with version 0.14.1:

from tokenizers import Tokenizer, decoders, normalizers
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel, Sequence

tokenizer = Tokenizer(BPE())
tokenizer.normalizer = normalizers.Sequence([])
tokenizer.pre_tokenizer = Sequence(
    [
        ByteLevel(add_prefix_space=False, use_regex=False),
    ]
)
# 'ö' is the character that represents byte 0xf6 in the ByteLevel byte map
tokenizer.add_tokens(["ilöveyou"])
tokenizer.decoder = decoders.ByteLevel()
encode_result = tokenizer.encode("ilöveyou")
print(encode_result.ids)
print(tokenizer.decode(encode_result.ids))

The output will be:

[0]
il�veyou

I believe the problem comes from https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/tokenizer/mod.rs#L832-L836. I don't think added tokens should be sent to the ByteLevel decoder, because they are extracted before pre-tokenization.
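
For reference, a minimal check (not in the original report) reproduces the mangling without a full tokenizer: the ByteLevel decoder maps 'ö' back to the single byte 0xf6, which is not valid UTF-8 on its own, so the raw added-token string cannot survive a round trip through it.

from tokenizers import decoders

# 'ö' is mapped back to byte 0xf6, which cannot be decoded as UTF-8 by itself
print(decoders.ByteLevel().decode(["ilöveyou"]))
# -> il�veyou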

ArthurZucker commented 7 months ago

Hey! Thanks for the report. The output is the same if you use tokenizers==0.13.3, so this is unrelated to the 0.14.1 release. I think one solution would be to have a normalizer that does ByteLevel and to set normalized to false for the tokens. Not sure we have another solution for now.

AddedTokens usually have to be sent to the decoder, because the pre-tokenization is applied to them.

DOGEwbx commented 7 months ago

Is it possible to let added_tokens_map_r of AddedVocabulary store the mapping from id to tokens after pre-tokenization, so that it generates the same output after the decoder?

ArthurZucker commented 7 months ago

It's not possible, no; normalizers are used for that purpose, however. Since the text is first split and normalized, and only then pre-tokenized, adding the pre-tokenized version of the token to added_tokens_map_r would still not work. In this case you could use a StripAccents normalizer to make sure accents are stripped, but that will also affect all the other tokens.
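
As a concrete illustration of that trade-off (a sketch, not from the thread), an NFD + StripAccents normalizer removes the accent before any byte-map character can reach the decoder, but it rewrites every input, not just the added token:

from tokenizers import normalizers

# StripAccents needs a decomposition step (NFD) first so the combining mark can be dropped
normalizer = normalizers.Sequence([normalizers.NFD(), normalizers.StripAccents()])
print(normalizer.normalize_str("ilöveyou"))
# -> iloveyou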

DOGEwbx commented 7 months ago

Ah, I see. Thanks for your explanation. Is there any plan to fix this bug?

github-actions[bot] commented 6 months ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

ArthurZucker commented 6 months ago

A PR for a backward-compatible fix is welcome! Otherwise I won't have time to dive into this 🤗

github-actions[bot] commented 4 months ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

ArthurZucker commented 2 months ago

See https://github.com/meta-llama/llama3/issues/67#issuecomment-2072952721. TL;DR, this should help:

>>> from tokenizers import AddedToken, pre_tokenizers
>>> from transformers import AutoTokenizer
>>> pre_tokenizers.ByteLevel(False,False).pre_tokenize_str("Bác")
[('BÃ¡c', (0, 3))]
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
>>> tokenizer.add_tokens(AddedToken("Bác", normalized=False, special=False))
>>> tokenizer.decode(tokenizer.encode("Bác"))
'<|begin_of_text|>Bác'

ArthurZucker commented 4 weeks ago

Re-opening, as the merge on main will be reverted in favor of a better fix soon.