huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

`Helsinki-NLP/opus-*` models `decode` not removing metaspace character #26018

Closed xenova closed 1 year ago

xenova commented 1 year ago

Who can help?

@ArthurZucker

Reproduction

Running

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-es')
tokenizer.decode(tokenizer("hello world")['input_ids'])
```

produces `▁hello▁world</s>`.

```python
tokenizer.decode(tokenizer("hello world")['input_ids'], skip_special_tokens=True)
```

produces `▁hello▁world`.

Expected behavior

The metaspace character (`▁`) should be removed, and the returned strings should be `hello world</s>` and `hello world`, respectively. This should be similar to:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('facebook/nllb-200-distilled-600M')
tokenizer.decode(tokenizer("hello world")['input_ids'], skip_special_tokens=True)
```

which produces `hello world`.
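
For context, the metaspace is SentencePiece's word-boundary marker `▁` (U+2581), the constant transformers refers to as `SPIECE_UNDERLINE`. A minimal sketch (not library code) of the expected cleanup:

```python
# Minimal illustration, not library code: SentencePiece marks word boundaries
# with U+2581, and decoding should map that marker back to a plain space.
SPIECE_UNDERLINE = "\u2581"  # ▁, the constant name used inside transformers

raw = "\u2581hello\u2581world</s>"
print(raw.replace(SPIECE_UNDERLINE, " ").strip())  # hello world</s>
```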

tanaymeh commented 1 year ago

I would like to work on fixing this, @xenova!

tanaymeh commented 1 year ago

@ArthurZucker I need some guidance here! I suppose this is not as simple as a regex replacement, right? Should I get in touch with the Helsinki-NLP team about this, or do you think there is a programmatic way to solve it?

ArthurZucker commented 1 year ago

Hey! Sure:

  1. Make sure that this is a bug, and that the original tokenizer's behaviour is not this one (a sanity-check sketch follows this comment).
  2. Check whether this is only a fast-tokenizer issue (meaning try `use_fast=False` and check the outputs as well).
  3. Try to re-convert the model; maybe it was not correctly uploaded. Check `convert_slow_tokenizer.py` in transformers to see the conversion. That is where you will find whether `add_prefix_space` was used or not. Also check the normalizers, post_processors, and decoders!

Cheers!
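
A minimal way to sanity-check item 1 against the raw SentencePiece model directly. This assumes the `source.spm` file has been downloaded from the `Helsinki-NLP/opus-mt-en-es` repo; the filename follows the Marian checkpoint convention, so verify it against the actual snapshot:

```python
# Sketch for step 1: compare against the raw SentencePiece model directly.
# Assumes "source.spm" was downloaded from Helsinki-NLP/opus-mt-en-es
# (the filename is the Marian checkpoint convention).
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("source.spm")

pieces = sp.EncodeAsPieces("hello world")
print(pieces)                   # pieces carry the ▁ word-boundary marker
print(sp.DecodePieces(pieces))  # the raw model's own decode yields 'hello world'
```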

xenova commented 1 year ago
  1. I am quite sure this is a bug - I don't think it makes sense to keep these metaspace characters. See here for example (`LlamaTokenizer` removes it). You can also search the codebase for `SPIECE_UNDERLINE`: in each case, it is removed when decoding. And this is not present for the `MarianTokenizer` (which is what these models use).
  2. From what I can tell, there is no fast version of the tokenizer:

     ```python
     from transformers import AutoTokenizer

     tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-es', use_fast=False)
     tokenizer.decode(tokenizer("hello world")['input_ids'])
     # outputs the same: '▁hello▁world</s>'
     ```
  3. See 2.

tanaymeh commented 1 year ago

Thanks a lot, @xenova and @ArthurZucker for your comments!

From what I understand, I need to change `MarianTokenizer` so that it removes metaspace characters, and then re-convert the `Helsinki-NLP/opus-*` models. Please correct me if I am wrong!

xenova commented 1 year ago

> and then re-convert the `Helsinki-NLP/opus-*` models.

You shouldn't need to re-convert any models. The `vocab.json`, `merges.txt`, and `tokenizer_config.json` will all stay the same.

All you should need to do is update `MarianTokenizer` to replace the `▁` with a space.
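
A sketch of what the patched method could look like. The `spm_source`/`spm_target` attributes, the `_decode_use_source_tokenizer` flag, and the special-token loop follow `MarianTokenizer`'s existing structure; treat the exact shape as an assumption, with the final `replace` being the only functional change:

```python
# Sketch of a patched MarianTokenizer.convert_tokens_to_string; the exact
# surrounding structure is assumed from the existing implementation.
from typing import List

SPIECE_UNDERLINE = "\u2581"  # the metaspace character

def convert_tokens_to_string(self, tokens: List[str]) -> str:
    """Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise"""
    sp_model = self.spm_source if self._decode_use_source_tokenizer else self.spm_target
    current_sub_tokens = []
    out_string = ""
    for token in tokens:
        # make sure special tokens are not decoded with the sentencepiece model
        if token in self.all_special_tokens:
            out_string += sp_model.decode_pieces(current_sub_tokens) + token + " "
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
    out_string += sp_model.decode_pieces(current_sub_tokens)
    # the fix: map the metaspace marker back to a plain space
    return out_string.replace(SPIECE_UNDERLINE, " ").strip()
```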

tanaymeh commented 1 year ago

Got it, thanks @xenova! I used the same logic as `LlamaTokenizer`, but now, instead of `▁hello▁world` as output, I get `hello▁world`, which is still wrong.

Should I use string replacement or regex to remove the metaspace character instead?

xenova commented 1 year ago

You could probably just do something similar to this:

https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/mbart/tokenization_mbart.py#L304-L307

but here. e.g., `return out_string.strip()` → `return out_string.replace(SPIECE_UNDERLINE, " ").strip()`

@ArthurZucker Is this good practice for sentencepiece tokenizers? From what I can tell, `sp_model.decode_pieces` is not used very often, so this decode block might be quite outdated itself.

tanaymeh commented 1 year ago

Thanks for the comment @xenova!

I did the following in my PR; is it acceptable too?

```python
def convert_tokens_to_string(self, tokens: List[str]) -> str:
    """Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise"""
    if tokens[0].startswith(SPIECE_UNDERLINE):
        tokens[0] = tokens[0][1:]

    # Other code in between

    out_string += sp_model.decode_pieces(current_sub_tokens)
    out_string = out_string.replace(SPIECE_UNDERLINE, " ")
    return out_string.strip()
```
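
A hypothetical check once the patch lands; the printed values are the expected outputs from the issue description, not observed results:

```python
# Hypothetical verification with a patched MarianTokenizer; expected outputs
# are taken from the issue description above, not from an actual run.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-es')
ids = tokenizer("hello world")['input_ids']
print(tokenizer.decode(ids))                            # hello world</s>
print(tokenizer.decode(ids, skip_special_tokens=True))  # hello world
```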