That's very interesting, and I can confirm we have this issue. `gemma` would just error out if you pass an int and not a list, with no proper warning, while the fast tokenizer works.
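Here is a minimal repro sketch of that behavior (the checkpoint name is only an assumption for illustration, not taken from this thread):

```python
from transformers import AutoTokenizer

# Hypothetical repro: any checkpoint with both a slow and a fast tokenizer should do.
slow = AutoTokenizer.from_pretrained("google/gemma-2b", use_fast=False)
fast = AutoTokenizer.from_pretrained("google/gemma-2b", use_fast=True)

some_id = len(slow) - 1            # a valid single token id
print(fast.decode(some_id))        # the fast tokenizer accepts a bare int
print(slow.decode([some_id]))      # the slow tokenizer is fine with a list
print(slow.decode(some_id))        # the slow tokenizer errors out / misbehaves on a bare int
```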
I think adding a test in `test_tokenization_common` will help us know which models fail and which we have to update.
Yes, you're right. I added this test case in `test_tokenization_common`:
```python
def test_single_id(self):
    tokenizer = self.get_tokenizer()
    rust_tokenizer = self.get_rust_tokenizer()
    vocab_size = len(tokenizer)
    int_single_id = vocab_size - 1
    list_single_id = [vocab_size - 1]
    self.assertEqual(tokenizer.decode(int_single_id), tokenizer.decode(list_single_id))
    self.assertEqual(rust_tokenizer.decode(int_single_id), rust_tokenizer.decode(list_single_id))
```
The test results are as below (scroll to the bottom to view the 33 failed models):
```text
> self.assertEqual(tokenizer.decode(int_single_id), tokenizer.decode(list_single_id))
E AssertionError: 'l o w e s t' != 'lowest'
E - l o w e s t
E + lowest
tests/test_tokenization_common.py:4208: AssertionError
__________________ SqueezeBertTokenizationTest.test_single_id __________________
self =
```
Feel free to open a PR for a fix. IMO we should not have spaces added in this case.
No problem, I will try to do this, but there is some other research work that needs to be pushed forward at the moment, so I may do it later.
Hi :)
I'm pretty sure the issue is not how `spaces_between_special_tokens` is used, but that single tokens are split into letters here. To fix it, I'd suggest adding the following before iterating over the tokens:
```python
if isinstance(filtered_tokens, str):
    filtered_tokens = [filtered_tokens]
```
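As a plain-Python illustration of why the guard matters (this is just a sketch, not the library code): iterating over a `str` yields its characters, which is exactly how a single token like `lowest` ends up as `l o w e s t`.

```python
# Sketch: for a single int id, convert_ids_to_tokens returns a plain str.
filtered_tokens = "lowest"

# Without the guard, the decode loop iterates character by character.
print(" ".join(filtered_tokens))      # "l o w e s t"

# With the proposed guard, the token stays intact.
if isinstance(filtered_tokens, str):
    filtered_tokens = [filtered_tokens]
print(" ".join(filtered_tokens))      # "lowest"
```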
I ran a couple of the test cases that were reported to be failing above with a slightly modified version of the test function proposed by @Ki-Seki, and they pass now:
```python
def test_single_id(self):
    tokenizer = self.get_tokenizer()
    vocab_size = len(tokenizer)
    int_single_id = vocab_size - 1
    list_single_id = [vocab_size - 1]
    self.assertEqual(tokenizer.decode(int_single_id), tokenizer.decode(list_single_id))
    if self.test_rust_tokenizer:
        rust_tokenizer = self.get_rust_tokenizer()
        self.assertEqual(rust_tokenizer.decode(int_single_id), rust_tokenizer.decode(list_single_id))
```
Unfortunately, I can't run all of the test cases (I keep running into weird Python segmentation errors that occur even without having changed the library at all). Does anyone know a trick for running the test cases anyway, or is it OK if I create a pull request and wait for the CI tests?
You can create a PR and rely on the CIs for sure! 🤗
Hello @ArthurZucker and all, I don't think this is an issue related to the specific ids, but rather a general problem. I tested a bit on my local machine, but to make sure my local setup isn't the cause, I also tested on Colab.
It looks to me like the problem is (i) a signature mismatch between the `_decode` methods of the `PreTrainedTokenizerBase` and `PreTrainedTokenizer` classes:
https://github.com/huggingface/transformers/blob/74b92c62560b7ade42d35a49f9063adc8b805c4a/src/transformers/tokenization_utils_base.py#L3913-L3915
https://github.com/huggingface/transformers/blob/74b92c62560b7ade42d35a49f9063adc8b805c4a/src/transformers/tokenization_utils.py#L1062-L1064
The fast tokenizer (`PreTrainedTokenizerFast`) has the correct signature:
https://github.com/huggingface/transformers/blob/74b92c62560b7ade42d35a49f9063adc8b805c4a/src/transformers/tokenization_utils_fast.py#L640-L642
Consequently, the slow tokenizer's `_decode` handles only a list of ids, not a single id. If `filtered_tokens` is a single string rather than a list of strings, the loop iterates over its characters and processes them one by one, so @MariaHei is totally right:
https://github.com/huggingface/transformers/blob/74b92c62560b7ade42d35a49f9063adc8b805c4a/src/transformers/tokenization_utils.py#L1082
Also, there are not many decoding tests, though there are lots of encoding tests :blush: I added a quick signature fix and return statements, and also added some decode tests in my PR.
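For readers following along, here is a rough sketch of the shape of such a fix (paraphrased, not the actual diff from the PR):

```python
from typing import List, Union

# Sketch only: accept a bare int like the base class / fast tokenizer do,
# and keep the guard for filtered_tokens suggested earlier in the thread.
def _decode(self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, **kwargs) -> str:
    if isinstance(token_ids, int):
        token_ids = [token_ids]
    filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
    if isinstance(filtered_tokens, str):
        filtered_tokens = [filtered_tokens]
    ...  # the rest of the existing decoding loop stays the same
```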
System Info
`transformers` version: 4.39.0.dev0

Who can help?
@ArthurZucker
Information
Tasks
`examples` folder (such as GLUE/SQuAD, ...)

Reproduction
Code
Output
Expected behavior
Consistent behavior. For example, when decoding the single ID, the output could also be `##~`.

Suspected rationale: In `src/transformers/tokenization_utils.py`, the `_decode` function incorrectly uses `spaces_between_special_tokens`, and then adds spaces between the sub-tokens.