castorini / duobert

Multi-stage passage ranking: monoBERT + duoBERT

Segment_id for two docs #4

Open jingtaozhan opened 4 years ago

jingtaozhan commented 4 years ago

I noticed that when the TFRecord is generated, the two documents are assigned different segment ids (1 and 2). However, type_vocab_size is 2 according to the provided bert_config.json. So I wonder what the actual segment ids for the two docs are.

rodrigonogueira4 commented 4 years ago

Sorry, the wrong bert_config.json was uploaded for duoBERT. The correct value is type_vocab_size=3.

I updated the file (duobert-large-msmarco-pretrained-and-finetuned.zip)

Thanks for catching this!

jingtaozhan commented 4 years ago

Thanks for your response. However, the finetuned model has only two token type embeddings.

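    import os
    import tensorflow as tf

    # List every variable stored in the checkpoint together with its shape.
    tf_path = os.path.abspath(tf_checkpoint_path)
    init_vars = tf.train.list_variables(tf_path)
    for name, shape in init_vars:
        print(name, shape)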

The shape of token_type_embeddings is [2, 1024]. I checked modeling.py:

    token_type_table = tf.get_variable(
        name=token_type_embedding_name,
        shape=[token_type_vocab_size, width],
        initializer=create_initializer(initializer_range))
    # This vocab will be small so we always do one-hot here, since it is always
    # faster for a small vocabulary.
    flat_token_type_ids = tf.reshape(token_type_ids, [-1])
    one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
    token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
    token_type_embeddings = tf.reshape(token_type_embeddings,
                                       [batch_size, seq_length, width])
    output += token_type_embeddings

Maybe the second doc has no segment embedding due to the tf.one_hot function.
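
For illustration, here is a minimal pure-Python sketch of the one-hot lookup in the modeling.py excerpt above. It assumes that tf.one_hot produces an all-zero row for an index outside [0, depth). The toy table and the helper names one_hot and lookup are made up for this example and are not part of the repo.

```python
def one_hot(index, depth):
    """One-hot row of length depth; an out-of-range index gives all zeros
    (assumed to mirror tf.one_hot)."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


def lookup(token_type_ids, table):
    """Embed ids as one_hot(ids) @ table, like the matmul in modeling.py."""
    depth = len(table)
    width = len(table[0])
    out = []
    for idx in token_type_ids:
        row = one_hot(idx, depth)
        out.append([sum(row[k] * table[k][j] for k in range(depth))
                    for j in range(width)])
    return out


# Toy [2, 3] table standing in for the [2, 1024] checkpoint tensor.
toy_table = [[0.1, 0.2, 0.3],
             [0.4, 0.5, 0.6]]

# Segment ids 0 (query), 1 (first doc), 2 (second doc).
embedded = lookup([0, 1, 2], toy_table)
print(embedded)
```

Under that assumption, id 2 contributes a zero vector, so the second doc would be trained with no segment embedding at all rather than raising an error.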

rodrigonogueira4 commented 4 years ago

> Maybe the second doc has no segment embedding due to the tf.one_hot function.

Probably. BTW, did you find this bug while using it with pytorch-transformers? If so, that would explain why I've never seen an error with the TF implementation.

jingtaozhan commented 4 years ago

Yes, I used the PyTorch implementation and the embedding module raised an error. Do you plan to train another version to fix this?

rodrigonogueira4 commented 4 years ago

Yes, that is what I will do next. Thanks for your answers.

pertschuk commented 4 years ago

Is there a workaround for this right now?

Is tf.one_hot then embedding the token_type_ids like this?

    0 -> [1 0]
    1 -> [0 1]
    2 -> [0 0]

If this is indeed the case, a workaround is possible by modifying the modeling_bert.py file in transformers:

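        # Ids >= the table size get a zero vector, mirroring an all-zero one-hot row.
        num_types = self.token_type_embeddings.num_embeddings
        in_range = (token_type_ids < num_types).unsqueeze(-1)
        safe_ids = token_type_ids.clamp(max=num_types - 1)
        token_type_embeddings = self.token_type_embeddings(safe_ids)
        token_type_embeddings = token_type_embeddings * in_range.to(token_type_embeddings.dtype)

carlos-gemmell commented 4 years ago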

@jingtaozhan, @pertschuk Were you able to successfully run the model with the correct type embeddings? I have the same issue with the [2,1024] tensor. How do we map the current weights to accommodate this?

carlos-gemmell commented 4 years ago

I can confirm this works correctly when loading the weights into HuggingFace (PyTorch).

The pretrained dir needs to contain the files under the names the loader expects, which only requires changing the file names.

import torch
from transformers import BertForSequenceClassification, BertTokenizer

class BertForPassageRanking(BertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        # Top-level tensors that are copied into the classifier head after loading.
        self.weight = torch.autograd.Variable(torch.ones(2, config.hidden_size),
                                              requires_grad=True)
        self.bias = torch.autograd.Variable(torch.ones(2), requires_grad=True)

bert_ranking = BertForPassageRanking.from_pretrained("saved_models/duoBERT/",
                                                     from_tf=True)
bert_ranking.classifier.weight.data = bert_ranking.weight.data
bert_ranking.classifier.bias.data = bert_ranking.bias.data

# Append an all-zero row so that segment id 2 maps to a zero embedding: [2, 1024] -> [3, 1024].
type_embed_weight = bert_ranking.bert.embeddings.token_type_embeddings.weight.data
zero_row = torch.zeros(1, type_embed_weight.size(1))
bert_ranking.bert.embeddings.token_type_embeddings.weight.data = torch.cat((type_embed_weight, zero_row))
bert_ranking.eval()

tokenizer = BertTokenizer("saved_models/monoBERT/vocab.txt")

query = 'how can I unfollow polaris 400 emails'

bad_passage = 'Best Answer: Plastics are used in wide range of things. So it is produced in a very huge amount and its convenience is undeniable. Recycling of plastic is very important because it is made from the oil which will cause the regular depletion of this limited resource.With the recycling of plastic we can save oil and can use it for longer time. Moreover recycling do not cause harm to the quality of plastics.est Answer: Plastics are used in wide range of things. So it is produced in a very huge amount and its convenience is undeniable. Recycling of plastic is very important because it is made from the oil which will cause the regular depletion of this limited resource.'

good_passage= "polaris 400 starter. Follow polaris 400 starter to get e-mail alerts and updates on your eBay Feed. Unfollow polaris 400 starter to stop getting updates on your eBay Feed.Yay! You're now following polaris 400 starter in your eBay Feed.ollow polaris 400 starter to get e-mail alerts and updates on your eBay Feed. Unfollow polaris 400 starter to stop getting updates on your eBay Feed. Yay! You're now following polaris 400 starter in your eBay Feed."

def custom_numericalize(query, docA, docB):
    query_ids = [tokenizer.cls_token_id] + tokenizer.encode(query, add_special_tokens=False) + [tokenizer.sep_token_id]
    query_token_type_ids = [0]*len(query_ids)

    docA_ids = tokenizer.encode(docA, add_special_tokens=False) + [tokenizer.sep_token_id]
    docA_token_type_ids = [1]*len(docA_ids)

    docB_ids = tokenizer.encode(docB, add_special_tokens=False) + [tokenizer.sep_token_id]
    docB_token_type_ids = [2]*len(docB_ids)

    input_ids = torch.tensor(query_ids+docA_ids+docB_ids).unsqueeze(0)
    input_type_ids = torch.tensor(query_token_type_ids+docA_token_type_ids+docB_token_type_ids).unsqueeze(0)
    return input_ids, input_type_ids

input_ids, input_type_ids = custom_numericalize(query, good_passage, bad_passage)
outputs = bert_ranking(input_ids, token_type_ids=input_type_ids)
outputs # tensor([[-0.2688,  0.3415]])

input_ids, input_type_ids = custom_numericalize(query, bad_passage, good_passage)
outputs = bert_ranking(input_ids, token_type_ids=input_type_ids)
outputs # tensor([[ 0.5154, -0.5022]])

The outputs flip when the order of the two passages is swapped, as expected.
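
As a rough sanity check of the zero-row trick, the pure-Python sketch below pads a toy two-row table with a zero row and looks up segment ids 0, 1 and 2 by index. The helper names pad_with_zero_rows and index_lookup and the toy values are hypothetical. It only illustrates the idea and does not touch the real checkpoint.

```python
def pad_with_zero_rows(table, new_size):
    """Return a copy of table with all-zero rows appended up to new_size rows."""
    width = len(table[0])
    extra = new_size - len(table)
    return [list(row) for row in table] + [[0.0] * width for _ in range(extra)]


def index_lookup(token_type_ids, table):
    """Plain row lookup, as an embedding layer does; an id past the end raises IndexError."""
    return [table[i] for i in token_type_ids]


# Toy [2, 2] table standing in for the [2, 1024] token type embeddings.
two_row_table = [[0.1, 0.2],
                 [0.3, 0.4]]
three_row_table = pad_with_zero_rows(two_row_table, 3)

# Query tokens use id 0, the first doc id 1, the second doc id 2.
segment_ids = [0, 0, 1, 1, 2, 2]
embedded = index_lookup(segment_ids, three_row_table)
print(embedded)
```

With the unpadded table the same lookup fails for id 2. With the padded table id 2 resolves to a zero vector, which is what an all-zero one-hot row would have contributed.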