yao8839836 / kg-bert

KG-BERT: BERT for Knowledge Graph Completion
Apache License 2.0
679 stars, 141 forks

Positive Labels for Negative Examples in Test Data (Link Prediction Task) #11

Closed: spacewalk01 closed this issue 10 months ago

spacewalk01 commented 4 years ago

In run_bert_link_prediction.py, processor._create_examples(tail_corrupt_list, "test", args.data_dir) produces the labeled input examples for KG-BERT. But I noticed that in your code, shown below, all examples in the test set, including the negative ones, get the positive label "1". Could you please explain why? Looking forward to your response. Thank you.

    def _create_examples(self, lines, set_type, data_dir):
        # ...
        if set_type == "dev" or set_type == "test":
            label = "1"
            guid = "%s-%s" % (set_type, i)
            text_a = head_ent_text
            text_b = relation_text
            text_c = tail_ent_text
            self.labels.add(label)
        # ...

yao8839836 commented 4 years ago

@batselem

Hi, thank you for reading the code. Link prediction computes the rank of the entities in the correct (positive) test triples, so the labels are "1".

spacewalk01 commented 4 years ago

Thank you very much for the quick response! Yes, but tail_corrupt_list also contains corrupted triples. How are they labeled? Skimming through the code, it looks like every triple in tail_corrupt_list gets the label "1".

Thank you.
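For readers following the question, here is a minimal sketch of the list structure being discussed. It assumes the filtered setting (the positive test triple is placed first, followed by every tail-replaced triple that is not already a known triple); the names build_tail_corrupt_list and label_examples are hypothetical and are not KG-BERT's own code.

```python
from typing import List, Sequence, Set, Tuple

Triple = Tuple[str, str, str]


def build_tail_corrupt_list(
    test_triple: Triple,
    entities: Sequence[str],
    known_triples: Set[Triple],
) -> List[Triple]:
    """Return [positive triple] + filtered tail-corrupted triples (hypothetical helper).

    The positive test triple is always at index 0. Every other entry replaces
    the tail with another entity, skipping corrupted triples that are already
    known (assumed filtered setting).
    """
    head, relation, tail = test_triple
    corrupt_list: List[Triple] = [test_triple]
    for ent in entities:
        if ent == tail:
            continue
        candidate = (head, relation, ent)
        if candidate not in known_triples:
            corrupt_list.append(candidate)
    return corrupt_list


def label_examples(triples: List[Triple]) -> List[Tuple[Triple, str]]:
    """Attach the label "1" to every triple, mirroring the quoted snippet."""
    return [(t, "1") for t in triples]


if __name__ == "__main__":
    ents = ["a", "b", "c", "d"]
    known = {("a", "r", "b"), ("a", "r", "c")}
    lst = build_tail_corrupt_list(("a", "r", "b"), ents, known)
    print(lst)  # [('a', 'r', 'b'), ('a', 'r', 'a'), ('a', 'r', 'd')]
    print(label_examples(lst))
```

The point of the sketch is only that the positive triple sits at index 0 and that every entry, corrupted or not, receives the same string label.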

spacewalk01 commented 4 years ago

Ah, so even though they all get the label "1", you only use the rank of the first triple, which is the original correct triple: rank1 = np.where(argsort1 == 0)[0][0]. Am I correct? Thanks.
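The ranking step in that line can be sketched with NumPy as follows. This is an illustration under stated assumptions, not the repository's code: scores is a hypothetical 1-D array of "probability of being correct" values with the positive triple at index 0, candidates are sorted from highest to lowest score, and the +1 that turns the 0-based position into a 1-based rank is an assumption about how ranks are reported.

```python
import numpy as np


def rank_of_positive(scores: np.ndarray) -> int:
    """Return the 1-based rank of the candidate at index 0 (hypothetical helper).

    scores holds one score per candidate triple; index 0 is assumed to be the
    positive test triple and the remaining entries are corrupted triples.
    """
    # Candidate indices ordered from highest to lowest score.
    argsort1 = np.argsort(-scores, kind="stable")
    # Position of the positive triple (index 0) in that ordering.
    rank1 = np.where(argsort1 == 0)[0][0]
    return int(rank1) + 1


scores = np.array([0.80, 0.95, 0.10, 0.60])
print(rank_of_positive(scores))  # 2
```

With kind="stable", a corrupted triple whose score ties the positive one is placed after it (because index 0 comes first), so ties are broken in the positive triple's favor; a different tie-breaking rule would change the rank only when scores are exactly equal.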

yao8839836 commented 4 years ago

@batselem

Yes, all the triples in tail_corrupt_list get the label "1". Labeling them this way is not appropriate, but only the "1" of the target positive triple is used to obtain the values (probabilities of being correct) of both the target positive triple and the corrupted triples.

        # get the dimension corresponding to current label 1 of all corrupted triples and the positive test triple
        rel_values = preds[:, all_label_ids[0]]

Here all_label_ids[0] is the label of the target positive triple. The "1" labels of the corrupted triples are not used in any computation, so the ranks should be correct.
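A small sketch of why the corrupted triples' labels do not matter. It assumes a two-label classifier with label list ["0", "1"] and a predictions array with one row per candidate triple; the helper positive_scores and the toy numbers are hypothetical. Because only all_label_ids[0] picks the column, relabeling the corrupted rows leaves the selected values unchanged.

```python
import numpy as np

# Assumed label list for the binary triple classifier:
# index 0 = "0" (incorrect), index 1 = "1" (correct).
label_list = ["0", "1"]
label_map = {label: i for i, label in enumerate(label_list)}


def positive_scores(preds: np.ndarray, labels: list) -> np.ndarray:
    """Select, for every candidate, the column given by the target's label id.

    preds: (num_candidates, num_labels) classifier outputs.
    labels: string labels of the candidates; only labels[0] (the target
    positive triple) influences the result.
    """
    all_label_ids = np.array([label_map[lab] for lab in labels])
    # The same column is taken for every row, chosen by the first label only.
    return preds[:, all_label_ids[0]]


preds = np.array([[0.2, 0.8],   # target positive triple
                  [0.9, 0.1],   # corrupted triple
                  [0.3, 0.7]])  # corrupted triple

a = positive_scores(preds, ["1", "1", "1"])  # everything labeled "1", as in the issue
b = positive_scores(preds, ["1", "0", "0"])  # corrupted triples labeled "0"
print(a, b)
print(np.array_equal(a, b))  # True
```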

spacewalk01 commented 4 years ago

I understand now. Thank you very much. Excellent work!

yao8839836 commented 4 years ago

@batselem

Thank you!