PretrainedTransformerTokenizer doesn't work with seq2seq dataset reader #4798

Closed: JohnGiorgi closed this issue 3 years ago

JohnGiorgi commented 3 years ago



As far as I can tell, PretrainedTransformerTokenizer is not compatible with the seq2seq dataset reader of allennlp-models when it is used as the source_tokenizer. The same error, raised from this try/except block, is triggered in multiple cases:

  1. When allennlp.common.util.START_SYMBOL and allennlp.common.util.END_SYMBOL are not in the pretrained transformer's vocabulary. I was able to solve this in the config as follows:

    "dataset_reader": {
        "type": "copynet_seq2seq",
        "target_namespace": "target_tokens",
        "source_tokenizer": {
            "type": "pretrained_transformer",
            "model_name": "distilroberta-base",
            "tokenizer_kwargs": {
                "additional_special_tokens": ["@start@", "@end@"]
            }
        }
    }

  2. If PretrainedTransformerTokenizer.add_special_tokens is True (the default) for wordpiece-based tokenizers.
  3. For any BPE-based tokenizer I tried.

The error arises because, in all of the cases listed above, the list returned by self._source_tokenizer.tokenize in the try/except block contains more than two tokens:

    try:
        self._start_token, self._end_token = self._source_tokenizer.tokenize(
            start_symbol + " " + end_symbol
        )
    except ValueError:
        raise ValueError(
            f"Bad start or end symbol ('{start_symbol}', '{end_symbol}') "
            f"for tokenizer {self._source_tokenizer}"
        )

Python traceback:

```
2020-11-16 16:53:24,760 - CRITICAL - root - Uncaught exception
Traceback (most recent call last):
  File "/project/6006286/johnmg/allennlp-models/allennlp_models/generation/dataset_readers/", line 98, in __init__
    self._start_token, self._end_token = self._source_tokenizer.tokenize(
ValueError: too many values to unpack (expected 2)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/johnmg/seq2rel/bin/allennlp", line 33, in
    sys.exit(load_entry_point('allennlp', 'console_scripts', 'allennlp')())
  File "/project/6006286/johnmg/allennlp/allennlp/", line 34, in run
    main(prog="allennlp")
  File "/project/6006286/johnmg/allennlp/allennlp/commands/", line 118, in main
    args.func(args)
  File "/project/6006286/johnmg/allennlp/allennlp/commands/", line 110, in train_model_from_args
    train_model_from_file(
  File "/project/6006286/johnmg/allennlp/allennlp/commands/", line 170, in train_model_from_file
    return train_model(
  File "/project/6006286/johnmg/allennlp/allennlp/commands/", line 236, in train_model
    model = _train_worker(
  File "/project/6006286/johnmg/allennlp/allennlp/commands/", line 453, in _train_worker
    train_loop = TrainModel.from_params(
  File "/project/6006286/johnmg/allennlp/allennlp/common/", line 595, in from_params
    return retyped_subclass.from_params(
  File "/project/6006286/johnmg/allennlp/allennlp/common/", line 627, in from_params
    kwargs = create_kwargs(constructor_to_inspect, cls, params, **extras)
  File "/project/6006286/johnmg/allennlp/allennlp/common/", line 198, in create_kwargs
    constructed_arg = pop_and_construct_arg(
  File "/project/6006286/johnmg/allennlp/allennlp/common/", line 305, in pop_and_construct_arg
    return construct_arg(class_name, name, popped_params, annotation, default, **extras)
  File "/project/6006286/johnmg/allennlp/allennlp/common/", line 339, in construct_arg
    return annotation.from_params(params=popped_params, **subextras)
  File "/project/6006286/johnmg/allennlp/allennlp/common/", line 595, in from_params
    return retyped_subclass.from_params(
  File "/project/6006286/johnmg/allennlp/allennlp/common/", line 629, in from_params
    return constructor_to_call(**kwargs)  # type: ignore
  File "/project/6006286/johnmg/allennlp-models/allennlp_models/generation/dataset_readers/", line 102, in __init__
    raise ValueError(
ValueError: Bad start or end symbol ('@start@', '@end@') for tokenizer
```
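
The inner "ValueError: too many values to unpack (expected 2)" is ordinary Python tuple unpacking, which the reader then re-raises as "Bad start or end symbol". A minimal pure-Python sketch of that mechanism, with hand-written token lists standing in for the tokenizer output (illustrative, not captured from a real run; the RoBERTa list matches the output reported later in this thread, and the BERT lists assume @start@/@end@ were registered as additional special tokens). `CASES` and `unpack_start_end` are hypothetical names:

```python
# Hand-written stand-ins for what tokenize("@start@ @end@") returns in each case
# (illustrative only; not captured from a real tokenizer run).
CASES = {
    "bert, add_special_tokens=False": ["@start@", "@end@"],
    "bert, add_special_tokens=True": ["[CLS]", "@start@", "@end@", "[SEP]"],
    "roberta, add_special_tokens=False": ["@start@", "Ġ", "@end@"],
}


def unpack_start_end(tokens):
    """Mimic the reader's unpacking: exactly two tokens, or a ValueError."""
    start_token, end_token = tokens  # raises ValueError unless len(tokens) == 2
    return start_token, end_token


for name, tokens in CASES.items():
    try:
        print(f"{name}: {unpack_start_end(tokens)}")
    except ValueError as err:
        print(f"{name}: ValueError: {err}")
```

Only the first case unpacks cleanly; any extra token (a [CLS]/[SEP] wrapper or a "Ġ" for the space) is enough to trigger the error.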

OS: Linux

Python version: 3.8.0

Output of pip freeze:

Steps to reproduce

Example source:

The proximate cause of the error can be reproduced as follows:

```python
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.common.util import START_SYMBOL, END_SYMBOL

tokenizer_kwargs = {"additional_special_tokens": [START_SYMBOL, END_SYMBOL]}

# Case 1, don't add AllenNLP's start/end symbols to the vocabulary
# this WILL fail
tokenizer = PretrainedTransformerTokenizer("bert-base-uncased")
start_token, end_token = tokenizer.tokenize(START_SYMBOL + " " + END_SYMBOL)

# Case 2, set add_special_tokens=True (the default) in PretrainedTransformerTokenizer for a wordpiece-based tokenizer
# this WON'T fail
tokenizer = PretrainedTransformerTokenizer("bert-base-uncased", tokenizer_kwargs=tokenizer_kwargs, add_special_tokens=False)
start_token, end_token = tokenizer.tokenize(START_SYMBOL + " " + END_SYMBOL)
# this WILL fail
tokenizer = PretrainedTransformerTokenizer("bert-base-uncased", tokenizer_kwargs=tokenizer_kwargs, add_special_tokens=True)
start_token, end_token = tokenizer.tokenize(START_SYMBOL + " " + END_SYMBOL)

# Case 3, BPE-based tokenizers fail regardless
# this WILL fail
tokenizer = PretrainedTransformerTokenizer("distilroberta-base", tokenizer_kwargs=tokenizer_kwargs, add_special_tokens=False)
start_token, end_token = tokenizer.tokenize(START_SYMBOL + " " + END_SYMBOL)
# this WILL fail
tokenizer = PretrainedTransformerTokenizer("distilroberta-base", tokenizer_kwargs=tokenizer_kwargs, add_special_tokens=True)
start_token, end_token = tokenizer.tokenize(START_SYMBOL + " " + END_SYMBOL)
```

epwalsh commented 3 years ago

What if you just disable adding start or end symbols?

If that doesn't work for your use case, you could also change the start and end symbols to something that you know is in your transformer's vocab, like [SEP] or [CLS].

JohnGiorgi commented 3 years ago

> What if you just disable adding start or end symbols?

At the very least, you would need target_add_end_token to be True so that the decoder can stop before it hits the max_decoding_length, right? If that's correct, then the ValueError is still raised even if source_add_start_token and source_add_end_token are False (I confirmed this).

> If that doesn't work for your use case, you could also change the start and end symbols to something that you know is in your transformer's vocab, like [SEP] or [CLS].

Is there a clean way to change AllenNLP's START_SYMBOL and END_SYMBOL, i.e., in a config or otherwise? I went looking but couldn't find it.

epwalsh commented 3 years ago

> At the very least, you would need target_add_end_token to be True so that the decoder can stop before it hits the max_decoding_length, right?

Not necessarily. If your tokenizer adds its own end symbol, then the decoder can just go off of that.
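
To illustrate the point, here is a generic greedy-decoding loop in plain Python. It is a sketch, not AllenNLP's decoder: `greedy_decode` and the scripted model are hypothetical stand-ins. Decoding stops as soon as the model emits whichever token is configured as the end token, so a tokenizer-provided end token such as RoBERTa's "</s>" can play that role, assuming the model is set up to treat it as the end token.

```python
from typing import Callable, List


def greedy_decode(
    next_token: Callable[[List[str]], str],
    start_token: str,
    end_token: str,
    max_decoding_steps: int,
) -> List[str]:
    """Generic greedy loop: stop at `end_token` or after `max_decoding_steps`."""
    output = [start_token]
    for _ in range(max_decoding_steps):
        token = next_token(output)
        if token == end_token:
            break  # the model predicted the end of the sequence
        output.append(token)
    return output[1:]  # drop the start token


# A scripted stand-in for a trained decoder that ends sequences with RoBERTa's "</s>".
script = iter(["the", "cat", "sat", "</s>", "never", "reached"])
print(greedy_decode(lambda prefix: next(script), "<s>", "</s>", max_decoding_steps=10))
# ['the', 'cat', 'sat']
```

If the model never emits `end_token`, the loop runs for the full `max_decoding_steps`, which is the situation the question above is concerned with.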

> Is there a clean way to change AllenNLP's START_SYMBOL and END_SYMBOL, i.e., in a config or otherwise?

Yes, just set start_symbol and end_symbol in your dataset reader:
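
For example, a minimal sketch of a `seq2seq` reader config that reuses BERT's own special tokens (only the relevant keys are shown). It assumes the source tokenizer is created with `add_special_tokens` set to `false`; otherwise tokenizing "[CLS] [SEP]" would add a second [CLS]/[SEP] pair and the same two-token check would fail. This exact configuration has not been verified here.

```json
"dataset_reader": {
    "type": "seq2seq",
    "source_tokenizer": {
        "type": "pretrained_transformer",
        "model_name": "bert-base-uncased",
        "add_special_tokens": false
    },
    "start_symbol": "[CLS]",
    "end_symbol": "[SEP]"
}
```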

JohnGiorgi commented 3 years ago

> Not necessarily. If your tokenizer adds its own end symbol, then the decoder can just go off of that.

Gotcha. That is probably the best solution for me.

> Yes, just set start_symbol and end_symbol in your dataset reader:

Thanks. Not sure how I missed that.

It still seems like there is a bug here. BPE-based tokenizers throw an error if at least one of source_add_start_token, source_add_end_token, target_add_start_token, or target_add_end_token is True, even if you take care to add AllenNLP's start and end symbols to the vocabulary and prevent the tokenizer from adding its own special tokens with add_special_tokens=False:

from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.common.util import START_SYMBOL, END_SYMBOL

tokenizer_kwargs = {"additional_special_tokens": [START_SYMBOL, END_SYMBOL]}
tokenizer = PretrainedTransformerTokenizer("distilroberta-base", tokenizer_kwargs=tokenizer_kwargs, add_special_tokens=False)
tokenizer.tokenize(START_SYMBOL + " " + END_SYMBOL)
# >> [@start@, Ġ, @end@]

This makes the try/except block of Seq2SeqDatasetReader fail: the tokenizer returns three tokens instead of two, because it encodes the separating space as "Ġ".

What if the check were more explicit? Something like:

tokens = self._source_tokenizer.tokenize(start_symbol + " " + end_symbol)
if tokens[0].text != start_symbol or tokens[-1].text != end_symbol:
    raise ValueError(
        f"Bad start or end symbol ('{start_symbol}', '{end_symbol}') "
        f"for tokenizer {self._source_tokenizer}"
    )
self._start_token, self._end_token = tokens[0], tokens[-1]

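
A self-contained sketch of that check, runnable without AllenNLP: `Token` is a minimal stand-in dataclass (only the `.text` attribute) and `find_start_end` is a hypothetical helper, not part of the library. It accepts RoBERTa's three-token output while still rejecting a tokenizer that wraps the symbols in [CLS]/[SEP].

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Token:
    """Minimal stand-in for an AllenNLP Token (only `.text`)."""

    text: str


def find_start_end(
    tokens: List[Token], start_symbol: str, end_symbol: str
) -> Tuple[Token, Token]:
    """Return the first/last tokens if they are exactly the start/end symbols.

    Anything in between (e.g. the "Ġ" that RoBERTa emits for the space) is ignored.
    """
    if not tokens or tokens[0].text != start_symbol or tokens[-1].text != end_symbol:
        raise ValueError(f"Bad start or end symbol ('{start_symbol}', '{end_symbol}')")
    return tokens[0], tokens[-1]


roberta_like = [Token("@start@"), Token("Ġ"), Token("@end@")]
print(find_start_end(roberta_like, "@start@", "@end@"))
# (Token(text='@start@'), Token(text='@end@'))
```

Comparing against the raw symbols also means a wordpiece tokenizer that splits @start@ into several pieces (case 1 above) is still rejected.
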
epwalsh commented 3 years ago

I think that's a good idea. Want to make a PR? Feel free to tag me as a reviewer. I also just realized the start/end_symbol parameters are missing from the docstring. Would be good to update the docstring as well.

JohnGiorgi commented 3 years ago

Sounds good, I'll do both!

JohnGiorgi commented 3 years ago

@epwalsh I took a closer look at the start_symbol and end_symbol arguments in Seq2SeqDatasetReader, and it looks like they are never used (besides in the try/except). Either this is a bug, or their purpose is not to override the defaults START_SYMBOL and END_SYMBOL?

epwalsh commented 3 years ago

Hmm, it's a weird way of doing it, but the result of that try/except is essentially the same as

self._start_symbol = start_symbol
self._end_symbol = end_symbol
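
To make the data flow concrete: start_symbol and end_symbol matter only through the tokens stored by that try/except. A sketch of how such stored tokens would then be attached to each sequence, assuming the reader prepends/appends copies of them when the corresponding add_start_token/add_end_token flag is set. `add_boundary_tokens` and the `Token` dataclass are hypothetical stand-ins, not AllenNLP code.

```python
import copy
from dataclasses import dataclass
from typing import List


@dataclass
class Token:
    """Minimal stand-in for an AllenNLP Token (only `.text`)."""

    text: str


def add_boundary_tokens(
    tokens: List[Token],
    start_token: Token,
    end_token: Token,
    add_start: bool = True,
    add_end: bool = True,
) -> List[Token]:
    """Return a new list with copies of the start/end tokens attached."""
    result = list(tokens)
    if add_start:
        result.insert(0, copy.deepcopy(start_token))
    if add_end:
        result.append(copy.deepcopy(end_token))
    return result


# start_token/end_token play the role of self._start_token/self._end_token,
# i.e. whatever tokenize(start_symbol + " " + end_symbol) produced.
source = [Token("hello"), Token("world")]
print([t.text for t in add_boundary_tokens(source, Token("[CLS]"), Token("[SEP]"))])
# ['[CLS]', 'hello', 'world', '[SEP]']
```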