keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

No tokenizer option to add special tokens (`[MASK]`, `[PAD]`) inside string input #1395

Open admiraltidebringer opened 7 months ago

admiraltidebringer commented 7 months ago

Describe the bug

The BERT tokenizer can't tokenize the [MASK] token. It should return 103, but it returns 1031, 7308, 1033.

Proof

keras_nlp library:

    import keras_nlp

    tokenizer = keras_nlp.models.BertTokenizer.from_preset('bert_tiny_en_uncased', sequence_length=12)
    tokenizer(['i am going to [MASK] to study math', 'the day before we went to [MASK] to cure illness'])

result:

    <tf.Tensor: shape=(2, 12), dtype=int32, numpy=
    array([[1045, 2572, 2183, 2000, 1031, 7308, 1033, 2000, 2817, 8785,    0,    0],
           [1996, 2154, 2077, 2057, 2253, 2000, 1031, 7308, 1033, 2000, 9526, 7355]], dtype=int32)>

Hugging Face:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Serdarmuhammet/bert-base-banking77")
    tokenizer(['i am going to [MASK] to study math', 'the day before we went to [MASK] to cure illness'], return_tensors='tf', max_length=20, padding=True)['input_ids']

result:

    <tf.Tensor: shape=(2, 12), dtype=int32, numpy=
    array([[ 101, 1045, 2572, 2183, 2000,  103, 2000, 2817, 8785,  102,    0,    0],
           [ 101, 1996, 2154, 2077, 2057, 2253, 2000,  103, 2000, 9526, 7355,  102]], dtype=int32)>

abuelnasr0 commented 7 months ago

Use `keras_nlp.models.BertMaskedLMPreprocessor` and it will mask the input for you: https://keras.io/api/keras_nlp/models/bert/bert_masked_lm_preprocessor/

admiraltidebringer commented 7 months ago

I used that too. It doesn't work correctly either: `keras_nlp.models.BertMaskedLMPreprocessor` also returned 1031, 7308, 1033 for [MASK].

abuelnasr0 commented 7 months ago

@admiraltidebringer Don't put [MASK] in your text. Just pass the original text without any changes to `BertMaskedLMPreprocessor`. It will automatically select a fraction of the tokens (`mask_selection_rate`, 0.15 by default) and replace most of them with the [MASK] token. It will also produce the labels for `BertMaskedLM`.
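To make the idea in this comment concrete, here is a minimal pure-Python sketch of masking applied to token ids after tokenization, with the original ids kept as labels. It is not the keras_nlp implementation: `mask_tokens`, `mask_rate`, and `seed` are hypothetical names, the 0.15 rate is an assumption, and the real preprocessor reportedly also swaps in random tokens or leaves some selected tokens unchanged. The ids are taken from the outputs quoted earlier in this issue.

```python
import random

MASK_ID = 103  # id of [MASK] in the uncased BERT vocabulary, per this issue


def mask_tokens(token_ids, mask_rate=0.15, seed=None):
    """Replace a random subset of token ids with MASK_ID.

    Returns (masked_ids, masked_positions, labels), where labels holds the
    original id at each masked position. Hypothetical helper, not a
    keras_nlp API.
    """
    rng = random.Random(seed)
    # Mask at least one token when there is anything to mask.
    n_masked = min(len(token_ids), max(1, round(len(token_ids) * mask_rate)))
    positions = sorted(rng.sample(range(len(token_ids)), n_masked))
    masked = list(token_ids)
    for pos in positions:
        masked[pos] = MASK_ID
    labels = [token_ids[pos] for pos in positions]
    return masked, positions, labels


# Ids for an unmasked sentence, built from the ids shown in this issue.
ids = [1045, 2572, 2183, 2000, 2000, 2817, 8785]
masked, positions, labels = mask_tokens(ids, seed=0)
print(masked, positions, labels)
```

The point is that the raw text never contains the literal string `[MASK]`; the mask id is written into the token sequence directly.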

admiraltidebringer commented 7 months ago

I understand what you are saying, but a tokenizer should do this and encode [MASK] to 103. You can check this in the PyTorch and Hugging Face models. So this needs to be changed.

abuelnasr0 commented 7 months ago

@admiraltidebringer It would be a good feature if it were added, of course.

admiraltidebringer commented 7 months ago

I think it should have been added already; it's a default feature of tokenizers.

mattdangerw commented 7 months ago

Agreed this would be a great feature to add!

Note that it is probably not right to blanket force this on all users. Let's say you are fine-tuning on an article describing tokenization techniques (actually not unreasonable to expect in today's world). You don't want external text from a training source to tokenize to a control character like this. The same goes for [START], [END], etc.

But for things like an example on keras.io, this would be super useful.

@abuelnasr0 is correct about the current best practice here (https://github.com/keras-team/keras-nlp/issues/1395#issuecomment-1881038059): add control characters by modifying the output of a tokenizer, not by modifying the input string. This is probably the safer thing for production systems anyway.
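A minimal sketch of that practice, assuming a toy whitespace tokenizer as a stand-in for a real one: the plain-text segments are tokenized on their own and the [MASK] id is spliced into the output between them. `VOCAB`, `toy_tokenize`, and `join_with_mask` are hypothetical names; the ids are copied from the outputs quoted in this issue.

```python
# Toy vocabulary; ids copied from the tokenizer outputs quoted in this issue.
VOCAB = {
    "[PAD]": 0,
    "[MASK]": 103,
    "i": 1045,
    "am": 2572,
    "going": 2183,
    "to": 2000,
    "study": 2817,
    "math": 8785,
}


def toy_tokenize(text):
    """Stand-in for a real tokenizer: lowercase, split on whitespace, look up ids."""
    return [VOCAB[word] for word in text.lower().split()]


def join_with_mask(segments, mask_id=VOCAB["[MASK]"]):
    """Tokenize each plain-text segment and place mask_id between segments."""
    ids = []
    for index, segment in enumerate(segments):
        if index > 0:
            ids.append(mask_id)
        ids.extend(toy_tokenize(segment))
    return ids


token_ids = join_with_mask(["i am going to", "to study math"])
print(token_ids)  # [1045, 2572, 2183, 2000, 103, 2000, 2817, 8785]
```

Because the mask id is never derived from the input string, text that happens to contain the characters `[MASK]` cannot turn into a control token by accident.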

Adding this would be welcome, but tricky, because it would require updating tf-text C++ ops and waiting a while for that to propagate to a TF release (roughly every 3 months). If we add it, it should be gated by an option that you can turn on during tokenizer construction.
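As an illustration of what such an opt-in gate could look like, here is a pure-Python toy tokenizer. This is only a sketch under assumptions, not the keras_nlp or tf-text implementation: `ToyTokenizer`, `special_tokens`, and `special_tokens_in_strings` are names chosen for this example, and the vocabulary is a small subset built from the ids quoted in this issue.

```python
import re

# Small vocabulary subset; ids copied from the outputs quoted in this issue.
VOCAB = {
    "[MASK]": 103,
    "[": 1031,
    "]": 1033,
    "mask": 7308,
    "to": 2000,
    "going": 2183,
    "study": 2817,
}


class ToyTokenizer:
    """Toy tokenizer whose special-token matching is off unless requested."""

    def __init__(self, vocab, special_tokens=(), special_tokens_in_strings=False):
        self.vocab = vocab
        # Default behaviour: split into words and single punctuation marks.
        pattern = r"\w+|[^\w\s]"
        self._specials = set()
        if special_tokens_in_strings and special_tokens:
            # Opt-in: try the special tokens first so they stay in one piece.
            specials = "|".join(re.escape(token) for token in special_tokens)
            pattern = f"{specials}|{pattern}"
            self._specials = set(special_tokens)
        self._pattern = re.compile(pattern)

    def __call__(self, text):
        ids = []
        for piece in self._pattern.findall(text):
            # Ordinary pieces are lowercased; matched special tokens are not.
            key = piece if piece in self._specials else piece.lower()
            ids.append(self.vocab[key])
        return ids


default = ToyTokenizer(VOCAB, special_tokens=["[MASK]"])
opt_in = ToyTokenizer(VOCAB, special_tokens=["[MASK]"], special_tokens_in_strings=True)

print(default("going to [MASK]"))  # [2183, 2000, 1031, 7308, 1033]
print(opt_in("going to [MASK]"))   # [2183, 2000, 103]
```

With the flag left at its default, untrusted training text that mentions `[MASK]` still tokenizes to ordinary punctuation and word ids, which is the behaviour reported at the top of this issue.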

abuelnasr0 commented 7 months ago

@mattdangerw I can add this feature for word_piece_tokenizer. I will open a PR for it.

mattdangerw commented 5 months ago

@abuelnasr0 I think there is an issue with https://github.com/keras-team/keras-nlp/pull/1397/: when lowercase is true, we lowercase all the special tokens, so they won't be preserved. Given that most word piece special tokens happen to use uppercase (e.g. [CLS], [PAD]), it seems like this might affect a lot of real-world usage.

Is this indeed a bug, and if so, any idea on how to fix it?
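The problem described here can be reproduced in a few lines of plain Python. The sketch below contrasts lowercasing the whole string with lowercasing only the text between special tokens. It is one possible approach and not necessarily what the follow-up PR does; `SPECIAL_TOKENS`, `lowercase_everything`, and `lowercase_except_special` are hypothetical names.

```python
import re

SPECIAL_TOKENS = ["[CLS]", "[PAD]", "[MASK]"]

# The capturing group makes re.split keep the special tokens in its result.
_SPECIAL_RE = re.compile(
    "(" + "|".join(re.escape(token) for token in SPECIAL_TOKENS) + ")"
)


def lowercase_everything(text):
    """Naive order: lowercasing first also rewrites the special tokens."""
    return text.lower()


def lowercase_except_special(text):
    """Split on special tokens first, then lowercase only the other parts."""
    parts = _SPECIAL_RE.split(text)
    return "".join(
        part if part in SPECIAL_TOKENS else part.lower() for part in parts
    )


text = "The day before we went to [MASK]"
print(lowercase_everything(text))      # the day before we went to [mask]
print(lowercase_except_special(text))  # the day before we went to [MASK]
```

After naive lowercasing, `[mask]` no longer matches the vocabulary entry `[MASK]`, so it falls back to being split into punctuation and word pieces.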

abuelnasr0 commented 4 months ago

@mattdangerw This is indeed a bug. I have opened PR keras-team/keras-nlp#1543 to fix it. Thanks for reporting it, and I am sorry for not noticing it in my first PR.