google-research / bert

TensorFlow code and pre-trained models for BERT
https://arxiv.org/abs/1810.04805
Apache License 2.0

How to handle labels when using the BERT wordpiece tokenizer #646

Open dangal95 opened 5 years ago

dangal95 commented 5 years ago

I am trying to do multi-class sequence classification using the BERT uncased base model and tensorflow/keras. However, I have an issue with labeling my data after running the BERT wordpiece tokenizer: I am unsure how I should modify my labels once the tokenizer splits words into pieces.

I have read several open and closed issues on GitHub about this problem, and I have also read the BERT paper published by Google. Section 4.3 of the paper explains how to adjust the labels, but I'm having trouble translating it to my case. I have also read the tokenization section of the official BERT repository README, which explains how to build a mapping from the original tokens to the new tokens and says it can be used to project my labels.

I have used the code provided in the README and managed to create labels in the way I think they should be. However, I am not sure if this is the correct way to do it. Below is an example of a tokenized sentence and its labels before and after using the BERT tokenizer. As a side note, I have adjusted some of the code in the tokenizer so that it does not split certain words on punctuation, because I would like them to remain whole.

This is the code to create the mapping:

bert_tokens = []
label_to_token_mapping = []

bert_tokens.append("[CLS]")

for token in original_tokens:
    # record the index of the first wordpiece produced for this original token
    label_to_token_mapping.append(len(bert_tokens))
    bert_tokens.extend(tokenizer.tokenize(token, ignore_set=ignore_set))

bert_tokens.append("[SEP]")

For the example sentence this gives:

original_tokens = ['The', '<start>', 'eng-30-01258617-a', '<end>', 'frailty']
tokens = ['[CLS]', 'the', '<start>', 'eng-30-01258617-a', '<end>', 'frail', '##ty', '[SEP]']
labels = [0, 2, 3, 4, 1]
label_to_token_mapping = [1, 2, 3, 4, 5]

Using the mapping I adjust my label array and it becomes like the following:

labels = [0, 2, 3, 4, 1, 1]

Following this I add the padding label (5) for the [CLS], [SEP], and padding positions (let's say the maximum sequence length is 12), so finally my label array looks like this:

labels = [5, 0, 2, 3, 4, 1, 1, 5, 5, 5, 5, 5]

As you can see, since the last word (labeled 1) was split into two pieces, I now label both word pieces '1'.
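To make the whole procedure concrete, here is a minimal sketch of the expansion step I just described (the function name expand_labels and its keyword arguments are only placeholders of mine, not anything from the BERT code; tokenizer.tokenize stands in for my modified tokenizer, and I leave out the ignore_set argument for brevity):

def expand_labels(original_tokens, original_labels, tokenizer,
                  max_seq_length=12, pad_label=5, subword_label=None):
    """Project word-level labels onto wordpiece-level labels.

    If subword_label is None, every piece of a word gets that word's label;
    otherwise the first piece keeps the word's label and the remaining
    pieces get subword_label (the 'X'-style scheme from section 4.3).
    """
    labels = [pad_label]  # [CLS] position
    for token, label in zip(original_tokens, original_labels):
        for i, _ in enumerate(tokenizer.tokenize(token)):
            labels.append(label if (i == 0 or subword_label is None) else subword_label)
    labels.append(pad_label)  # [SEP] position
    labels += [pad_label] * (max_seq_length - len(labels))  # pad (assumes the sequence fits)
    return labels

With my modified tokenizer, expand_labels(original_tokens, [0, 2, 3, 4, 1], tokenizer) reproduces exactly the padded array shown above.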

I am not sure if this is correct. In section 4.3 of the paper the extra word pieces are labelled 'X', but I'm not sure if that is what I should also do in my case. The paper (https://arxiv.org/abs/1810.04805) gives the following example:

Jim    Hen    ##son  was  a puppet  ##eer
I-PER  I-PER    X     O   O   O       X

My final goal is to input a sentence into the model and get back an array that looks something like [5, 0, 0, 1, 1, 2, 3, 4, 5, 5, 5, 5], i.e. one label per word piece. I can then merge the word pieces back into words to recover the original length of the sentence, and with it the shape the prediction values should actually have.

Also, another option (following the section 4.3 example from the paper) would be to introduce a new label (say number '6') that is used only for the word-parts. My label array would then look like:

labels = [5, 0, 2, 3, 4, 1, 6, 5, 5, 5, 5, 5]
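(With the sketch further above, this is just expand_labels(original_tokens, [0, 2, 3, 4, 1], tokenizer, subword_label=6).)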

After training the model for a couple of epochs I attempt to make predictions and get strange values. For example, a word gets the label '5', which I reserve for padding, while padding positions get the label '1'. This makes me think that there is something wrong with the way I create my labels. Initially I did not adjust the labels at all, i.e. I left them as they were even after tokenizing the sentence, but that did not give good results either.

Any help would be greatly appreciated as I've been trying hard to find what I should do online but I haven't been able to figure it out yet. Thank you in advance!

Also, the following is the code I use to create my model:

from tensorflow.python.keras.layers import Input, Dense
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import backend as K
import tensorflow_hub as hub
import tensorflow as tf

class BertLayer(tf.layers.Layer):
    def __init__(self, bert_path, max_seq_length, n_fine_tune_layers=10, **kwargs):
        self.n_fine_tune_layers = n_fine_tune_layers
        self.trainable = True
        self.output_size = 768
        self.bert_path = bert_path
        self.max_seq_length = max_seq_length

        super(BertLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.bert = hub.Module(
            self.bert_path,
            trainable=self.trainable,
            name="{}_module".format(self.name)
        )
        trainable_vars = self.bert.variables

        # Remove unused layers
        # trainable_vars = [var for var in trainable_vars if not "/cls/" in var.name]
        trainable_vars = [var for var in trainable_vars
                          if not ("/cls/" in var.name) and not ("/pooler/" in var.name)]

        # Select how many layers to fine tune
        trainable_vars = trainable_vars[-self.n_fine_tune_layers:]

        # Add to trainable weights
        for var in trainable_vars:
            self._trainable_weights.append(var)

        for var in self.bert.variables:
            if var not in self._trainable_weights:
                self._non_trainable_weights.append(var)

        super(BertLayer, self).build(input_shape)

    def call(self, inputs):
        inputs = [K.cast(x, dtype="int32") for x in inputs]
        input_ids, input_mask, segment_ids = inputs
        bert_inputs = dict(
            input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
        )
        result = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
            "sequence_output"
        ]
        return result

    def compute_output_shape(self, input_shape):
        # sequence_output gives one 768-dimensional vector per wordpiece position
        return (input_shape[0][0], self.max_seq_length, self.output_size)

    def get_config(self):
        config = {'bert_path': self.bert_path, 'max_seq_length': self.max_seq_length}
        base_config = super(BertLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

# Build model
def build_model(bert_path, max_seq_length):
    in_id = Input(shape=(None,), name="input_ids")
    in_mask = Input(shape=(None,), name="input_masks")
    in_segment = Input(shape=(None,), name="segment_ids")
    bert_inputs = [in_id, in_mask, in_segment]

    bert_output = BertLayer(bert_path=bert_path, n_fine_tune_layers=3, max_seq_length=max_seq_length)(bert_inputs)
    dense = Dense(128, activation='relu')(bert_output)
    pred = Dense(7, activation='softmax',)(dense)

    model = Model(inputs=bert_inputs, outputs=pred)
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])
    model.summary()

    return model

def initialize_vars(sess):
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    K.set_session(sess)
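For completeness, this is roughly how I build and train the model (only a sketch: the TF Hub path and the train_* arrays are placeholders; the input arrays are padded to max_seq_length and train_labels is one-hot encoded with shape (num_examples, max_seq_length, 7)):

bert_path = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"
max_seq_length = 12

sess = tf.Session()
model = build_model(bert_path, max_seq_length)
initialize_vars(sess)

model.fit(
    [train_input_ids, train_input_masks, train_segment_ids],
    train_labels,
    epochs=2,
    batch_size=32,
)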

ashutoshsingh0223 commented 5 years ago

We have been trying to do the same thing. One thing we tried is tagging the sequences after wordpiece tokenization. So in our case

Jim    Hen    ##son  was  a puppet  ##eer  
 B-PER  I-PER  X     O    O   O       X

becomes

Jim    Hen    ##son  was  a puppet  ##eer    
B-PER  I-PER  I-PER  O    O   O        O

And when decoding, we merge the subtokens of each word back together and keep one tag per word, like this:

def convert_to_original_length(sentence, tags):
    r = []
    r_tags = []
    for index, token in enumerate(tokenizer.tokenize(sentence)):
        if token.startswith("##"):
            # continuation piece: glue it back onto the previous word, drop its tag
            if r:
                r[-1] = f"{r[-1]}{token[2:]}"
        else:
            r.append(token)
            r_tags.append(tags[index])
    return r_tags

We found this works better than tagging only the first subtoken of each word.
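For reference, the encoding direction (expanding the word-level tags onto the wordpieces before training) is roughly the mirror image of the function above. A sketch, assuming tags is aligned with the whitespace-split words of sentence and the function name is only illustrative:

def convert_to_wordpiece_tags(sentence, tags):
    # every sub-token of a word simply repeats that word's tag
    wp_tokens, wp_tags = [], []
    for word, tag in zip(sentence.split(), tags):
        for piece in tokenizer.tokenize(word):
            wp_tokens.append(piece)
            wp_tags.append(tag)
    return wp_tokens, wp_tags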

dsindex commented 5 years ago

@dangal95

I had a very similar problem before. In my case, I needed to integrate BERT embeddings with GloVe, ELMo, and CNN-based word embeddings. There are several possible ways to align them:

  1. Pick the first token embedding of a word, just like in the original BERT paper.
  2. Mean- or max-pool the token embeddings of a word (a sketch follows at the end of this comment).

Then, how do you compute the pooled embedding from a series of token embeddings?

  1. Run the BERT graph to get the token embeddings before running the main graph.
  2. Compute the pooled embeddings and feed them to the main graph.
  3. Run the main graph.

If you want fine-tuning, the method mentioned by @ashutoshsingh0223 will work better.
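For the mean-pooling option above, a rough sketch with illustrative names (token_embeddings would come from the BERT graph run in step 1, without the [CLS]/[SEP] rows):

import numpy as np

def mean_pool_word_embeddings(token_embeddings, words, tokenizer):
    # token_embeddings: (num_wordpieces, hidden_size) array aligned with
    # the wordpieces produced for `words`
    pooled = []
    start = 0
    for word in words:
        n_pieces = len(tokenizer.tokenize(word))
        pooled.append(token_embeddings[start:start + n_pieces].mean(axis=0))
        start += n_pieces
    return np.stack(pooled)  # shape: (num_words, hidden_size)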

sougata-fiz commented 4 years ago

@ashutoshsingh0223 how did you handle the [SEP] and [CLS] tokens?

YHN-ice commented 1 year ago

> We have been trying to do the same thing. One thing we tried is tagging the sequences after wordpiece tokenization. [...] We found it works better than taking the tag of the first subtoken.

I tried to adopt this scheme, but I ran into situations that are not trivial to handle. The problem is that the correspondence between the original word indexes and the sub-word indexes is hard to recover. The "##" prefix is a common hint for splits, but it is not sufficient in cases like these:

  1. 1998-01-29 -> 1998 - 01 - 29, where the correspondence should be [(0,0),(1,5)][^1]
  2. I'm 21 years old. -> I [unk] m 21 years old, where the correspondence should be [(0,0),(1,3),(2,4),(3,5),(4,6)]

[^1]: Here a list of index pairs expresses the correspondence; each pair gives the positions at which the same word starts in the original and in the tokenized sequence.

So I used another trick that is free of pattern matching: for every word in the original sentence, we tokenize it separately and collect the correspondence information:

def get_index_correspondence(sent, tokenizer):
    """
    Due to cases like "1996-08-22" => "1996", "-", "08", "-", "22", we need the exact position correspondence.
    A = ["Brussels", "1996-08-22"]
    B = ["br", "##us", "##se", "##ls", "1996", "-", "08", "-", "22"]
    """
    correspondence = [(0, 0)]
    for word in sent:
        (raw_end, expand_end) = correspondence[-1]
        # each new pair records where the next word starts in the raw and in the tokenized sequence
        correspondence.append((raw_end + 1, expand_end + len(tokenizer.tokenize(word))))
    return correspondence
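To then carry a word-level tag over to its sub-tokens (or, in the other direction, to read one tag per word off the sub-token predictions), the pairs can be used like this (a small illustrative example; the piece counts follow the docstring above):

sent = ["Brussels", "1996-08-22"]
tags = ["B-LOC", "O"]
corr = get_index_correspondence(sent, tokenizer)
# corr == [(0, 0), (1, 4), (2, 9)] for the tokenization shown in the docstring

expanded_tags = []
for i, tag in enumerate(tags):
    start, end = corr[i][1], corr[i + 1][1]
    expanded_tags.extend([tag] * (end - start))
# expanded_tags now holds one tag per wordpiece
# ([CLS]/[SEP] positions still need their own labels)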