huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Check in PreTrainedTokenizer can cause incorrect tokenization #13038

Closed: codedecde closed this issue 3 years ago

codedecde commented 3 years ago

Environment info

Who can help

@LysandreJik

Information

The `if not text.strip()` check in PreTrainedTokenizer can cause incorrect tokenization (and subsequent encoding) for space-only sequences (or sequences with leading or trailing spaces). This can be problematic for byte-only models (ByT5 etc.), can cause inconsistent tokenization between the Tokenizer and TokenizerFast classes, and can cause issues wherever the code assumes non-destructive behaviour of a tokenizer.

To reproduce

Steps to reproduce the behavior:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=False)
tokenizer_fast = AutoTokenizer.from_pretrained("roberta-base")
# Correct Tokenization
out = tokenizer_fast.tokenize(' ')
# The above results in ['Ġ'], which is correct
# Incorrect Tokenization
out = tokenizer.tokenize(' ')
# The above results in [], which is incorrect

# Example 2.
assert ' ' == tokenizer.decode(tokenizer.encode(' ', add_special_tokens=False))  # This will fail, since '' != ' '

Expected behavior

Leading and trailing spaces should be considered during tokenization, especially for non-destructive tokenizers.
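
The non-destructive expectation can be phrased as a round-trip property: decoding the encoding of a string should return the same string. The following is a minimal sketch of such a check. It assumes a toy UTF-8 byte codec rather than a real transformers tokenizer, and `round_trip_failures`, `byte_encode`, `byte_decode` and `stripping_encode` are hypothetical names introduced only for illustration.

```python
def round_trip_failures(encode, decode, samples):
    """Return the samples for which decode(encode(s)) != s."""
    return [s for s in samples if decode(encode(s)) != s]


# Toy byte-level codec (illustrative only, not a transformers tokenizer).
def byte_encode(text):
    return list(text.encode("utf-8"))


def byte_decode(ids):
    return bytes(ids).decode("utf-8")


# Same codec, but stripping whitespace first, loosely mimicking the reported behaviour.
def stripping_encode(text):
    return byte_encode(text.strip())


samples = [" ", " a ", "a"]
print(round_trip_failures(byte_encode, byte_decode, samples))       # []
print(round_trip_failures(stripping_encode, byte_decode, samples))  # [' ', ' a ']
```

With a real tokenizer, `encode` and `decode` could presumably be replaced by wrappers around `tokenizer.encode(..., add_special_tokens=False)` and `tokenizer.decode(...)`, as in Example 2 above.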

Proposed Solution

Changing the check from

if not text.strip():
    return []

to

if len(text) == 0:   # or if not text:
    return []

should be okay. Alternatively, a flag (e.g. remove_extra_whitespaces) could be added, with the current behaviour enabled only when the flag is passed as True.
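
The two checks and the optional flag can be compared in isolation. The following is a toy sketch, not the library's code: `toy_tokenize` is a hypothetical function, `remove_extra_whitespaces` is only the flag name suggested above, and the one-token-per-character split is a placeholder for real tokenization.

```python
def toy_tokenize(text, remove_extra_whitespaces=True):
    """Toy stand-in for the early-return check; not the transformers implementation."""
    if remove_extra_whitespaces:
        # Current behaviour: whitespace-only input is treated as empty.
        if not text.strip():
            return []
    elif not text:
        # Proposed behaviour: only a truly empty string short-circuits.
        return []
    # Hypothetical placeholder for real tokenization: one token per character.
    return list(text)


print(toy_tokenize(" "))                                  # []
print(toy_tokenize(" ", remove_extra_whitespaces=False))  # [' ']
print(toy_tokenize("", remove_extra_whitespaces=False))   # []
```

Defaulting the flag to True in this sketch keeps the existing behaviour for callers that do not opt in, which is one way to avoid changing results for current users.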

LysandreJik commented 3 years ago

May be of interest to @SaulLu

SaulLu commented 3 years ago

Thank you very much for the detailed issue, @codedecde!

This check was added to solve non-deterministic tokenization problems, and I think this solution was kept because, at the time, we did not see a use case for tokenizing a sentence containing only spaces (see issue and PR).

Could you please explain in which case you need to tokenize a sentence containing only a space? Thank you very much in advance!

codedecde commented 3 years ago

Hi @SaulLu. Thank you for responding, and really sorry for the late response. My use case is a little niche. I am training byte-level encoder models. To do the masking, I am using a BPE tokenizer with dropout and remapping its output back to the byte level. For example:

tokenized = tokenizer.tokenize("Huggingface is awesome")
# ['Hug', 'ging', 'face', 'Ġ', 'is', 'Ġawesome']
inputs_with_mask, masked_tokens = mask_function(tokenized)
# ['Hug', 'ging', <mask>, **<mask>**, 'is', 'Ġawesome'], [<pad>, <pad>, 'face', **'Ġ'**, <pad>, <pad>]
# The 'Ġ' token marked with ** will get destroyed later because of the issue
decoded_text = byte_tokenizer.decode(inputs_with_mask)
# Hugging<mask>**<mask>**is awesome
model_inputs, model_outputs = byte_tokenizer.encode(decoded_text, masked_tokens)
# ['H', 'u', 'g', 'g', 'i', 'n', 'g', <mask>, <mask>, <mask>, <mask>, **<mask>**, 'i', 's', ' ', 'a', 'w', 'e', 's', 'o', 'm', 'e']
# model_outputs = [<pad>,<pad>,<pad>,<pad>,<pad>,<pad>,<pad>, 'f', 'a', 'c', 'e', **''**, <pad>, ...]

In the above example, the mask enclosed in ** markers and its associated label are affected by the problem described above. Since this is a niche use case, having this behaviour enabled by a kwarg flag would be quite helpful (e.g. by default, leading and trailing spaces are always stripped, except when the flag is set to true).

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.