I can reproduce it for this particular text; there is indeed a difference between the fast and slow tokenizers.
from transformers import XLMRobertaTokenizer, AutoTokenizer
tok_s = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
tok_f = AutoTokenizer.from_pretrained("xlm-roberta-base")
line = "スナップリング SC-40"
tok_s(line).input_ids
# [0, 6, 3385, 17456, 46405, 76931, 17715, 41734, 2]
tok_f(line).input_ids
# [0, 6, 3385, 17456, 13451, 17462, 113810, 75862, 246514, 17715, 41734, 2]
cc @SaulLu
I was going to open a new issue, but it seems it may be related to this.
I am wondering if this is expected behavior? U+FF08 "（" and U+0028 "(" both encode to [0, 15, 2] using XLMRobertaTokenizer.from_pretrained("xlm-roberta-base").
CC: @patil-suraj
Hi @liweifriends126,
Thank you for bringing this problem to our attention! It is indeed a problem that the sequences of ids are not identical!
Investigating this, I think the problem lies in the encoding of u"\u30d5\u309a" (プ) and u"\u30af\u3099" (グ). Let me share my little test below:
# Define texts to compare
text_1 = u"\u30d5\u309a" # プ
text_2 = u"\u30af\u3099" # グ
# Installations
!pip install transformers
!pip install sentencepiece
!git clone https://github.com/pytorch/fairseq
!cd fairseq && pip install .  # a bare "!cd" does not persist across notebook lines
!wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz
!tar -xzvf xlmr.base.tar.gz
# Load the model in fairseq
from fairseq.models.roberta import XLMRModel
xlmr = XLMRModel.from_pretrained('/content/data/xlmr.base', checkpoint_file='model.pt')
xlmr.eval()
# Load the model in transformers
from transformers import AutoTokenizer
tokenizer_f = AutoTokenizer.from_pretrained('xlm-roberta-base')
# Compare encoding
import torch

def compare(text):
    fairseq_input_ids = xlmr.encode(text).tolist()
    fairseq_ids_to_tokens = [xlmr.decode(torch.tensor([id])) for id in fairseq_input_ids]
    fairseq_ids_to_tokens_unicode = [tok.encode('raw_unicode_escape') for tok in fairseq_ids_to_tokens]

    trfs_input_ids = tokenizer_f.encode(text)
    trfs_ids_to_tokens = tokenizer_f.convert_ids_to_tokens(trfs_input_ids)
    trfs_ids_to_tokens_unicode = [tok.encode('raw_unicode_escape') for tok in trfs_ids_to_tokens]

    print(f"{'Version':8}|{'Input ids':24}|{'Corresponding tokens':30}|Corresponding tokens in unicode format")
    print(f"{'fairseq':8}|{repr(fairseq_input_ids):24}|{repr(fairseq_ids_to_tokens):30}|{repr(fairseq_ids_to_tokens_unicode)}")
    print(f"{'trfs':8}|{repr(trfs_input_ids):24}|{repr(trfs_ids_to_tokens):30}|{repr(trfs_ids_to_tokens_unicode)}")
compare(text_1)
# Version |Input ids |Corresponding tokens |Corresponding tokens in unicode format
# fairseq |[0, 6, 16985, 2] |['', '', 'プ', ''] |[b'', b'', b'\\u30d7', b'']
# trfs |[0, 6, 17462, 113810, 2]|['<s>', '▁', 'フ', '゚', '</s>']|[b'<s>', b'\\u2581', b'\\u30d5', b'\\u309a', b'</s>']
compare(text_2)
# Version |Input ids |Corresponding tokens |Corresponding tokens in unicode format
# fairseq |[0, 6, 21300, 2] |['', '', 'グ', ''] |[b'', b'', b'\\u30b0', b'']
# trfs |[0, 6, 4758, 246514, 2] |['<s>', '▁', 'ク', '゙', '</s>']|[b'<s>', b'\\u2581', b'\\u30af', b'\\u3099', b'</s>']
What is surprising about this test is that sentencepiece transforms the \u30d5\u309a sequence into its composed version \u30d7 (and likewise for \u30af\u3099). The resulting character looks the same in both cases, but the unicode encoding is different: this is a problem for the consistency of the input to our model.
The bad news is that I don't know how sentencepiece manages to change a decomposed character into a composed one, because we take the normalization operation directly from the sentencepiece proto model.
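For reference, the composition sentencepiece seems to apply here matches standard Unicode NFC/NFKC composition; a quick check with Python's unicodedata (added purely for illustration, not part of the original test):
import unicodedata
text_1 = "\u30d5\u309a"  # フ + combining semi-voiced sound mark, renders as プ
text_2 = "\u30af\u3099"  # ク + combining voiced sound mark, renders as グ
# NFC (and NFKC) compose these pairs into the single code points U+30D7 / U+30B0,
# which are the code points that fairseq's tokens contain.
print(hex(ord(unicodedata.normalize("NFC", text_1))))  # 0x30d7
print(hex(ord(unicodedata.normalize("NFC", text_2))))  # 0x30b0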
Let me ping @Narsil who may have an idea of where the difference lies.
BTW: if you really need a fast implementation with better parity, maybe I can provide one after getting approval from my manager.
If you ever have time, I think this is indeed an important bug but most probably hard to solve! Thanks a lot for offering your help :pray:
Hi @kristjanArumae,
The case you report is, in my opinion, expected, because the encoding is identical between the code base of the XLM-R authors and the fast implementation in transformers: this is a normalization of the text chosen by the authors.
By reusing the functions defined in my previous comment, we can check that:
# The two characters from the report above (assumed from it): U+FF08 "（" and U+0028 "("
text_3 = "\uff08"
text_4 = "\u0028"

def compare_ids(text):
    fairseq_input_ids = xlmr.encode(text).tolist()
    trfs_input_ids = tokenizer_f.encode(text)

    print(f"{'Version':8}|Input ids")
    print(f"{'fairseq':8}|{repr(fairseq_input_ids)}")
    print(f"{'trfs':8}|{repr(trfs_input_ids)}")
compare_ids(text_3)
# Version |Input ids
# fairseq |[0, 15, 2]
# trfs |[0, 15, 2]
compare_ids(text_4)
# Version |Input ids
# fairseq |[0, 15, 2]
# trfs |[0, 15, 2]
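This is consistent with standard NFKC normalization, which folds the fullwidth parenthesis into the ASCII one; a quick illustration (added here, not from the original exchange):
import unicodedata
# NFKC maps the fullwidth "（" (U+FF08) to the ASCII "(" (U+0028),
# so both strings become identical before tokenization.
print(unicodedata.normalize("NFKC", "\uff08") == "\u0028")  # True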
Hi @SaulLu:
Thanks for the response. For the normalization logic, I think you can check the following file:
https://raw.githubusercontent.com/google/sentencepiece/master/data/nmt_nfkc.tsv
This file defines the normalization logic. For example, for the following line:
41 302 300 1EA6 # Ầ => Ầ
It means that if the sequence "41 302 300" (U+0041 "A" followed by the combining circumflex and grave accents) is encountered, it should be replaced with "1EA6" (Ầ).
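In Python, reading those rules could look roughly like this (a sketch; the function name is made up, and it assumes the file's columns are tab-separated as "source codepoints<TAB>target codepoints" with trailing "#" comments):
def load_nmt_nfkc_rules(path):
    rules = {}
    with open(path, encoding="utf-8") as f:
        for raw in f:
            line = raw.split("#", 1)[0].rstrip()  # drop the trailing comment
            if not line:
                continue
            fields = line.split("\t")
            if len(fields) < 2:
                continue
            src = "".join(chr(int(cp, 16)) for cp in fields[0].split())
            tgt = "".join(chr(int(cp, 16)) for cp in fields[1].split())
            rules[src] = tgt
    return rules
# e.g. rules["\u0041\u0302\u0300"] should give "\u1ea6" (Ầ) for the line quoted above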
Thanks
Hello everyone.
When the fast tokenizer was implemented, extreme care was taken that there was no divergence between the algorithms, with the exception of "AAA" -> ("AA", "A") vs ("A", "AA"): both segmentations are valid, have the same score, and the choice comes down to a floating-point divergence between the two code bases (f64 vs f32).
The check was run on ALL spm tokenizers, against the entire XNLI database (it seemed not too big, yet provides an ample amount of weird unicode oddities to be a good testing ground).
That doesn't mean something didn't work properly / wasn't checked against.
Here is the code to replicate the spm_precompiled code: https://github.com/huggingface/spm_precompiled
And here is how it's tied to the tokenizer: https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/normalizers/precompiled.rs
One definite potential suspect is the highly suspicious code defined here: https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/normalizers/precompiled.rs#L46
As mentioned by past me, this code is super odd, but it really seemed to operate that way at the time. Purely respecting the Trie on bytes didn't work, and working only on full graphemes didn't either; unfortunately I don't remember all the specifics. It could have been bad implementations on my end when trying those solutions. (Offsets are also a source of headaches for this code.)
IIRC all of the issues encountered were engraved in tests.
The bad grapheme here does seem to be of length 6, which magically makes it not be respected in our code: https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=562ab464479995b315bcc585c24b2e0a
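For reference, the decomposed grapheme is indeed 6 bytes long in UTF-8; a quick check in Python (added for illustration):
cluster = "\u30af\u3099"              # ク + combining voiced sound mark, renders as グ
print(len(cluster))                   # 2 code points
print(len(cluster.encode("utf-8")))   # 6 bytes (3 bytes per code point)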
I think we have to trust my past self, attempt to find better code, and add a huge amount of testing to make sure we don't break anything else.
I also noticed that sentencepiece itself has made some modifications to those files. They shouldn't have changed anything on the surface, but it's maybe something to keep in mind for this: https://github.com/google/sentencepiece/commit/fab966ad218c6d3449f7ebf088c8b891afbabec2
There's also a lot of details in the PR that originated this: https://github.com/huggingface/tokenizers/pull/401
Part of the explanation over there predates the Precompiled code, as I was attempting to use our own normalizers as the first rules. Precompiled should match spm 1-1 (it's just rewritten in Rust, but it's really a big copy-paste).
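For anyone who wants to poke at the precompiled normalizer directly from Python, something along these lines should work (a sketch; it assumes your sentencepiece build ships the protobuf bindings and that the tokenizers Python package exposes normalizers.Precompiled):
from huggingface_hub import hf_hub_download
from sentencepiece import sentencepiece_model_pb2 as model_pb2
from tokenizers import normalizers

# Grab the sentencepiece model shipped with xlm-roberta-base
spm_path = hf_hub_download("xlm-roberta-base", "sentencepiece.bpe.model")

# Pull the precompiled normalization charsmap out of the model proto
proto = model_pb2.ModelProto()
with open(spm_path, "rb") as f:
    proto.ParseFromString(f.read())

norm = normalizers.Precompiled(proto.normalizer_spec.precompiled_charsmap)

# Inspect what the precompiled normalizer does with the decomposed プ (U+30D5 U+309A)
print([hex(ord(c)) for c in norm.normalize_str("\u30d5\u309a")])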
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Environment info
transformers version: 4.16.2
Who can help
LysandreJik & SaulLu
Information
Model I am using (Bert, XLNet ...): xlm-roberta-base
The problem arises when using:
The tasks I am working on is:
To reproduce
Steps to reproduce the behavior:
Following is the code I use to run the XLM-Roberta tokenizer:
And the following is Hugging Face's output:
But when I run the same query with Google's sentencepiece (note that I saved the same query into a file and then used cat to send it to Google's encoder):
The result is not the same. And even if I take into account the fairseq map mentioned at huggingface, the output still does not match. Basically, when taking the map into account, Google's output corresponds to:
which does not match huggingface's output.
I think huggingface uses a fast implementation for the tokenization, but the fast implementation contains bugs.
BTW: if you really need a fast implementation with better parity, maybe I can provide one after getting approval from my manager.
Expected behavior
Huggingface's output should be {'input_ids': [0, 6, 3385, 17456, 46405, 76931, 17715, 41734, 2], ...}