huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Tokenizer `encode/decode` methods are inconsistent, TypeError: argument 'ids': 'list' object cannot be interpreted as an integer #28635

Open scruel opened 9 months ago

scruel commented 9 months ago

System Info

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

Run the following code:

from transformers import AutoTokenizer

text = "test"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# return_tensors='pt' keeps the leading batch axis: tensor of shape (1, seq_len)
encoded = tokenizer.encode(text, return_tensors='pt')
# decode expects a single sequence of ids, so the batched tensor makes it fail
result_text = tokenizer.decode(encoded, skip_special_tokens=True)
print(result_text)

This will raise an exception:

Traceback (most recent call last):
  File "main.py", line 8, in <module>
    tokenizer.decode(encoded, skip_special_tokens=True)
  File "/home/scruel/mambaforge/envs/vae/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 3748, in decode
    return self._decode(
           ^^^^^^^^^^^^^
  File "/home/scruel/mambaforge/envs/vae/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py", line 625, in _decode
    text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

Expected behavior

It should print the original text "test" rather than raise an exception (TypeError).
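
In the meantime, a workaround sketch (assuming the same gpt2 tokenizer as in the reproduction): either drop the batch axis before decoding, or use batch_decode, which accepts the batched output.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
encoded = tokenizer.encode("test", return_tensors="pt")  # shape (1, seq_len)

# Option 1: strip the leading batch axis before calling decode
print(tokenizer.decode(encoded[0], skip_special_tokens=True))

# Option 2: batch_decode handles the batch dimension and returns a list of strings
print(tokenizer.batch_decode(encoded, skip_special_tokens=True)[0])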

scruel commented 9 months ago

https://github.com/huggingface/transformers/blob/83f9196cc44a612ef2bd5a0f721d08cb24885c1f/src/transformers/tokenization_utils_fast.py#L596-L605

Why is the leading batch axis only removed when return_tensors is None? Considering the annotation of the text parameter of the _encode_plus method, the batch axis isn't needed at all, so why not remove it for all return_tensors types?
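
To illustrate the asymmetry the linked lines create (a sketch, assuming the gpt2 tokenizer from the reproduction above):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

ids_list = tokenizer.encode("test")                     # plain Python list of ints: batch axis already squeezed
ids_pt = tokenizer.encode("test", return_tensors="pt")  # torch.Tensor of shape (1, seq_len): batch axis kept

tokenizer.decode(ids_list)  # works
tokenizer.decode(ids_pt)    # raises the TypeError from the traceback above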

ArthurZucker commented 9 months ago

Fully agree with you. encode can take batched inputs, but decode only takes a single sequence; batch_decode is the one that supports a batch of inputs. It has been that way for a WHILE now. I plan to deprecate some of this and simplify the API in favor of just encode/decode that can support both batches and single inputs.

Would you like to work on a fix for this? 🤗
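
To make the current split concrete, a sketch of how the existing API pairs up (assuming the gpt2 tokenizer; not the future unified API):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# __call__ accepts a batch of texts and returns one id sequence per text...
batch_ids = tokenizer(["first text", "second text"])["input_ids"]

# ...but decode only accepts a single sequence; batch_decode is its batched counterpart
print(tokenizer.decode(batch_ids[0], skip_special_tokens=True))
print(tokenizer.batch_decode(batch_ids, skip_special_tokens=True))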

scruel commented 9 months ago

Sure, already working on it, I will create a PR later today.

scruel commented 9 months ago

Hi @ArthurZucker,

ArthurZucker commented 8 months ago

Sorry, for such a big change to the API I'd rather take care of it myself; I was more referring to a small PR that, for now, adds support for decoding a batch of inputs with decode!

About all the points you mentioned: for me the direction of the library is a bit different, apart from removing calls to TF everywhere, which indeed should be protected the same way as flax!
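
For illustration only, a minimal sketch of the kind of dispatch such a small decode-side change could do — a hypothetical standalone helper, not the actual PR or the library's API:

from typing import List, Union

def decode_any(tokenizer, token_ids, **kwargs) -> Union[str, List[str]]:
    """Hypothetical helper: route batched inputs to batch_decode, single sequences to decode."""
    # Tensors/arrays expose .ndim; a 2D object is treated as a batch
    if getattr(token_ids, "ndim", None) == 2:
        return tokenizer.batch_decode(token_ids, **kwargs)
    # A list whose first element is itself a sequence is treated as a batch
    if isinstance(token_ids, (list, tuple)) and token_ids and isinstance(token_ids[0], (list, tuple)):
        return tokenizer.batch_decode(token_ids, **kwargs)
    return tokenizer.decode(token_ids, **kwargs)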

scruel commented 8 months ago

Sorry, for such a big change to the API I'd rather take care of it myself; I was more referring to a small PR that, for now, adds support for decoding a batch of inputs with decode!

decode itself can't handle such tasks well, as I mentioned before:

it will be impossible to distinguish List[TextInput] from PreTokenizedInput on its own without changing the logic or adding extra parameters

So yes, it becomes a "big change". Since I already created a PR, you may take care of this based on it; no need to be sorry 🤗
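
To make the ambiguity concrete on the encode side (a sketch using the public is_split_into_words flag; the same list of strings can mean two different things):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

ambiguous = ["hello", "world"]

# Read as List[TextInput]: a batch of two texts -> two id sequences
as_batch = tokenizer(ambiguous)["input_ids"]

# Read as PreTokenizedInput: one pre-split text -> a single flat id sequence
as_pretokenized = tokenizer(ambiguous, is_split_into_words=True)["input_ids"]

print(as_batch)         # e.g. [[...], [...]]
print(as_pretokenized)  # e.g. [...]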

the direction of the library is a bit different

Can you explain this more? I think most of the points are just about code style, so they won't affect the direction of anything, but they may improve maintainability :)

Apart from removing calls to TF everywhere, which indeed should be protected the same way as flax!

Cool, I'm glad you also think so. We'd better think about this more carefully: putting an import statement inside a frequently called function is a bad idea, because the import statement re-runs on every call, even though Python only performs the real (heavy) import of a library once. So we also need to consider when that one real import should happen.
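
A sketch of that point (generic Python, not transformers code): the module is loaded only once, but a function-level import statement still re-runs the sys.modules lookup and name binding on every call.

import math     # module-level import: the heavy work happens once, at import time
import timeit

def with_inner_import(x):
    # This import statement executes on every call: the module is only truly
    # loaded once, but the sys.modules lookup and local binding repeat each call.
    import math as _math
    return _math.sqrt(x)

def with_module_level_import(x):
    return math.sqrt(x)

print(timeit.timeit(lambda: with_inner_import(2.0), number=1_000_000))
print(timeit.timeit(lambda: with_module_level_import(2.0), number=1_000_000))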

github-actions[bot] commented 7 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.