yao8839836 / kg-bert

KG-BERT: BERT for Knowledge Graph Completion
Apache License 2.0
679 stars 141 forks source link

你好 #12

Open lushishuai opened 4 years ago

lushishuai commented 4 years ago

你好,请问咱们这个代码Hits@指标能跑到论文中的实验结果吗?

yao8839836 commented 4 years ago

@lushishuai

您好,WN18RR的hits@10表现也比较不稳定, UMLS和FB15k-237的Hits@10比较稳定。可参考 #9 。

MedyG commented 3 years ago

@yao8839836 你好,我运行run_bert_relation_prediction.py得到的结果与论文结果相去甚远,是参数设置的不对吗

02/04/2021 18:31:46 - INFO - main - Test results 02/04/2021 18:31:46 - INFO - main - acc = 0.6901355995327657 02/04/2021 18:31:46 - INFO - main - eval_loss = 0.7939625151198486 02/04/2021 18:31:46 - INFO - main - global_step = 301980 02/04/2021 18:31:46 - INFO - main - loss = 0.3695593820695463

我的参数输入

--task_name kg --do_train --do_eval --do_predict --data_dir ./data/FB15K --bert_model bert-base-cased --max_seq_length 25 --train_batch_size 32 --learning_rate 5e-5 --num_train_epochs 20.0 --output_dir ./output_FB15K/ --gradient_accumulation_steps 1 --eval_batch_size 512

原代码读取entity2text.txt文件报字符错误,我换成了'utf'格式

with open(os.path.join(data_dir, "entity2text.txt"), 'r', encoding='utf') as f:
        ent_lines = f.readlines()
        for line in ent_lines:
            temp = line.strip().split('\t')
            ent2text[temp[0]] = temp[1]

    if data_dir.find("FB15") != -1:
        with open(os.path.join(data_dir, "entity2text.txt"), 'r', encoding='utf') as f:
            ent_lines = f.readlines()
            for line in ent_lines:
                temp = line.strip().split('\t')
                #first_sent_end_position = temp[1].find(".")
                ent2text[temp[0]] = temp[1]#[:first_sent_end_position + 1]              

另外由于网络环境问题,在线下载模型比较慢,我手动换成读取已经下载好的本地模型bert-base-cased-vocab.txtbert-base-cased.tar.gz

  tokenizer = BertTokenizer.from_pretrained("D:\\Documents\\code\\pretrained\\bert\\bert-base-cased-vocab.txt", do_lower_case=args.do_lower_case)
  model = BertForSequenceClassification.from_pretrained("D:\\Documents\\code\\pretrained\\bert\\bert-base-cased.tar.gz",
          cache_dir=cache_dir,
          num_labels=num_labels)

希望能得到你的帮助