huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

AutoTokenizer: Phi-3 drops spaces when decoding one token at a time #31643

Open Andrei-Aksionov opened 3 weeks ago

Andrei-Aksionov commented 3 weeks ago

System Info

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

from transformers import AutoTokenizer

phi_2_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
phi_3_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

for name, tokenizer in (("phi-2", phi_2_tokenizer), ("phi-3", phi_3_tokenizer)):
    print(f"Tokenizer: {name}")
    tokens = tokenizer.encode("This is a test string")
    print(f"{tokens=}")
    print(tokenizer.decode(tokens))
    print("".join([tokenizer.decode(token) for token in tokens]))
    print("-" * 50)
Tokenizer: phi-2
tokens=[1212, 318, 257, 1332, 4731]
This is a test string
This is a test string
--------------------------------------------------
Tokenizer: phi-3
tokens=[1, 910, 338, 263, 1243, 1347]
<s> This is a test string
<s>Thisisateststring
--------------------------------------------------
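A common workaround in streaming scenarios is to decode a growing prefix of ids and emit only the text delta, instead of decoding each id in isolation, so per-call whitespace handling cannot drop spaces. Below is a minimal sketch of that technique; `toy_decode` and `VOCAB` are hypothetical stand-ins that mimic the space-stripping behavior shown above, not the real Phi-3 tokenizer:

```python
def stream_decode(decode_fn, token_ids):
    """Yield text deltas by decoding a growing prefix and diffing against
    the previous decode, so leading-space handling stays consistent."""
    previous = ""
    for i in range(1, len(token_ids) + 1):
        current = decode_fn(token_ids[:i])
        yield current[len(previous):]
        previous = current


# Toy decoder mimicking a tokenizer that strips a leading space
# unless the sequence starts with the BOS token (id 1).
VOCAB = {1: "<s>", 910: " This", 338: " is", 263: " a", 1243: " test", 1347: " string"}

def toy_decode(ids):
    text = "".join(VOCAB[i] for i in ids)
    if ids and ids[0] != 1 and text.startswith(" "):
        text = text[1:]  # the space-dropping behavior from the report
    return text


ids = [1, 910, 338, 263, 1243, 1347]
print("".join(toy_decode([i]) for i in ids))       # spaces lost, as in the issue
print("".join(stream_decode(toy_decode, ids)))     # spaces preserved
```

With the real tokenizer, `decode_fn` would be `tokenizer.decode`; the prefix-diff makes the workaround independent of how any particular tokenizer treats a lone token.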

Expected behavior

I expect that, even when decoding a single token at a time, the resulting string should contain spaces between tokens. As shown above, the Phi-2 tokenizer has no such problem, but for some reason Phi-3 produces a concatenated string with the spaces dropped.

ArthurZucker commented 3 weeks ago

cc @itazap

itazap commented 2 weeks ago

Hey @Andrei-Aksionov , thanks for the reproducer! It has to do with Phi-3 being based on LlamaTokenizerFast and Phi-2 on CodeGen. LlamaTokenizerFast strips the leading whitespace so that it can manually add a prefix space when add_prefix_space is set. I'm looking into a fix now that handles this better!
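To illustrate why that stripping loses spaces piece-by-piece: SentencePiece-style tokenizers mark word boundaries with a "▁" metaspace, which decoding turns into a space; if the decoder then removes the leading space it injected via add_prefix_space, a token decoded on its own loses its boundary space. A toy sketch (not the actual LlamaTokenizerFast decoder, just the general idea under that assumption):

```python
def metaspace_decode(pieces, add_prefix_space=True):
    """Sketch of SentencePiece metaspace decoding: '▁' marks a word
    boundary and becomes a space; the injected prefix space is stripped."""
    text = "".join(pieces).replace("▁", " ")
    if add_prefix_space and text.startswith(" "):
        text = text[1:]  # undo the manually added prefix space
    return text


pieces = ["▁This", "▁is", "▁a", "▁test", "▁string"]
print(metaspace_decode(pieces))                          # full decode keeps spaces
print("".join(metaspace_decode([p]) for p in pieces))    # per-piece decode drops them
```

Decoding the whole sequence strips only the one genuine prefix space, but decoding piece-by-piece strips the boundary space of every piece, which is exactly the concatenated output in the report.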