huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

FastTokenizer not using the user_defined_symbols defined in the SentencePiece Model #28324

Closed: kitkhai closed this issue 10 months ago

kitkhai commented 10 months ago

System Info

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

from transformers.convert_slow_tokenizer import import_protobuf
from transformers import AutoTokenizer
from transformers import NllbTokenizer, NllbTokenizerFast

checkpoint = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.save_pretrained("old_tokenizer")

# Parse the saved SentencePiece model with the protobuf schema shipped in transformers.
model_pb2 = import_protobuf()
m = model_pb2.ModelProto()
with open("./old_tokenizer/sentencepiece.bpe.model", "rb") as f:
    m.ParseFromString(f.read())

# Append a new piece marked as USER_DEFINED (type 4 in the SentencePiece proto).
piece = model_pb2.ModelProto.SentencePiece()
piece.piece = "superlongword"
piece.score = -10
piece.type = 4  # 4 == USER_DEFINED

m.pieces.append(piece)
with open("temp_sentencepiece.bpe.model", "wb") as f:
    f.write(m.SerializeToString())

tokenizer_edited = NllbTokenizer(vocab_file="temp_sentencepiece.bpe.model", src_lang="zho_Hans", tgt_lang="eng_Latn")
tokenizer_edited_fast = NllbTokenizerFast(vocab_file="temp_sentencepiece.bpe.model", src_lang="zho_Hans", tgt_lang="eng_Latn")

sent = 'Hi there superlongword'
print(sent)
> Hi there superlongword

print("original tokenizer: ", tokenizer.tokenize(sent))
> original tokenizer:  ['▁Hi', '▁there', '▁super', 'long', 'word']

print("tokenizer with tokens: ", tokenizer_edited.tokenize(sent))
> tokenizer with tokens:  ['▁Hi', '▁there', '▁', 'superlongword']

print("tokenizer with tokens (Fast): ", tokenizer_edited_fast.tokenize(sent))
> tokenizer with tokens (Fast):  ['▁Hi', '▁there', '▁super', 'long', 'word']

Expected behavior

> Hi there superlongword
> original tokenizer:  ['▁Hi', '▁there', '▁super', 'long', 'word']
> tokenizer with tokens:  ['▁Hi', '▁there', '▁', 'superlongword']
> tokenizer with tokens (Fast):  ['▁Hi', '▁there', '▁', 'superlongword']

I faced a similar issue to the one raised in a question on the HF forum, where the OP trained the tokenizer with user_defined_symbols; in my case I added the token to the SentencePiece model file directly, without retraining.

Note that I could just use the add_tokens method to achieve the same outcome, but because of another issue I raised (#28218), I would like to avoid add_tokens if possible.
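For reference, the add_tokens route I am trying to avoid would look roughly like this (a minimal sketch; the exact handling of the leading '▁' around the added token may differ from the SentencePiece route):

from transformers import AutoTokenizer
from tokenizers import AddedToken

tokenizer_fast = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

# Register the word as a regular (non-special) added token; it is then
# matched before the BPE model ever sees the text.
tokenizer_fast.add_tokens([AddedToken("superlongword", normalized=True)])

print(tokenizer_fast.tokenize("Hi there superlongword"))
# expected to come out as a single piece, e.g. ['▁Hi', '▁there', 'superlongword']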

kitkhai commented 10 months ago

Additionally, is there a way to retrieve (and edit) the merge rules from the "slow" and "fast" tokenizers respectively?
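For what it's worth, the fast tokenizer's merge rules can at least be read out by serializing its Rust backend (a sketch, assuming the default BPE model; I don't know of a supported API for editing them in place):

import json
from transformers import AutoTokenizer

tokenizer_fast = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

# The backend serializes to the same JSON layout as tokenizer.json;
# for a BPE model the "model" section holds the vocab and the ordered merges.
state = json.loads(tokenizer_fast.backend_tokenizer.to_str())
merges = state["model"]["merges"]
print(len(merges), merges[:3])

The slow (sentencepiece-backed) tokenizer has no explicit merge list, as far as I understand: SentencePiece's BPE derives merge order from the piece scores in the protobuf, so editing scores there (as in the reproduction above) is the closest equivalent.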

ArthurZucker commented 10 months ago

Hey! A few things here. What you are trying to do is outside the scope of the supported features: adding a token should be done with the tokenizer.add_tokens function. To me, the fast version is more correct than the output you expect. If there are no merge rules for the new piece, there is absolutely no reason for the BPE model to fuse '▁super', 'long', 'word' into 'superlongword'. If anything, the slow version is the one behaving unexpectedly, specifically because sentencepiece does not really support adding tokens that way.
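To make that concrete, here is a sketch of how one could verify it against the converted fast tokenizer from the reproduction above (note that merge entries serialize either as "a b" strings or as ["a", "b"] pairs depending on the tokenizers version):

import json
from transformers import NllbTokenizerFast

# Uses the edited model file written in the reproduction above.
tok = NllbTokenizerFast(vocab_file="temp_sentencepiece.bpe.model",
                        src_lang="zho_Hans", tgt_lang="eng_Latn")

state = json.loads(tok.backend_tokenizer.to_str())
vocab, merges = state["model"]["vocab"], state["model"]["merges"]

def merged(m):
    # A merge serializes as "a b" (older tokenizers) or ["a", "b"] (newer).
    a, b = m.split(" ", 1) if isinstance(m, str) else m
    return a + b

print("superlongword" in vocab)                          # expected True: the piece reaches the vocab
print(any(merged(m) == "superlongword" for m in merges)) # expected False: no merge ever produces it

With the piece in the vocab but no merge path leading to it, BPE can never emit it during tokenization, which matches the fast tokenizer's output above.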