Closed chengbaitai closed 2 years ago
Sure. A simple way to use RED-GNN for relation prediction is to enumerate the relations and pick the one with the highest score for (h, ?, t). Since the number of relations in a KG is usually not large, the computation cost is bearable.
To improve efficiency, you can also extract the r-digraph between h and t as in Algorithm 1, and then enumerate the different relations r.
The key point is that a model can do both entity prediction and relation prediction if it can score the triple (h,r,t).
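A minimal sketch of that enumeration, assuming a hypothetical `score_fn(h, r, t)` that returns the model's score for a single triple. RED-GNN itself scores all candidate tails for a batch of (h, r) queries, so the function and toy values here are only illustrative:

```python
def predict_relation(score_fn, h, t, n_rel):
    """Return (best_relation, scores) for the query (h, ?, t)."""
    # Score the triple once per candidate relation.
    scores = [score_fn(h, r, t) for r in range(n_rel)]
    # The predicted relation is the one with the highest score.
    best = max(range(n_rel), key=scores.__getitem__)
    return best, scores


# Toy scorer standing in for a trained model (hypothetical values).
toy_scores = {("a", 0, "b"): 0.1, ("a", 1, "b"): 0.9, ("a", 2, "b"): 0.3}


def toy_score_fn(h, r, t):
    return toy_scores.get((h, r, t), 0.0)


best_rel, all_scores = predict_relation(toy_score_fn, "a", "b", n_rel=3)
print(best_rel, all_scores)
```

The cost is one scoring pass per relation, which is why this is only practical when the relation set is small.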
Hello! Thank you very much for your reply! I'm sorry to bother you again.
In your excellent work, scores is a 50 × 3007 matrix holding the score of each query (Qh, Qr) against each candidate answer (At).
After training, I enumerated the different relations r and chose the one with the highest score, without changing the loss function.
The code is:
    import numpy as np

    for i in range(n_batch):
        start = i * batch_size
        end = min(n_data, (i + 1) * batch_size)
        batch_idx = np.arange(start, end)
        subs, rels, objs = self.loader.get_batch_rels(batch_idx, data='valid')
        scores = []
        for rel in range(self.n_rel):
            # Score every candidate tail for (subs, rel); shape (batch, n_ent).
            rel_batch = np.full(len(subs), rel)
            score = self.model(subs, rel_batch, mode='valid').detach().cpu().numpy()
            # Keep only the score of the known tail of each query.
            pos_scores = score[np.arange(len(score)), np.asarray(objs)]
            scores.append(pos_scores)
        # scores has shape (n_rel, batch); take the best relation per query.
        scores = np.stack(scores)
        predict_values = scores.argmax(axis=0).tolist()
        print(predict_values)
        print(rels)
I didn't change the loss function because each element of your scores matrix is the score of a triple (h, r, t). What confuses me is that the results seem very poor: the predicted values are far from the actual values. Here are some of the results:
0.24
[7, 6, 6, 6, 3, 1, 1, 5, 10, 1, 3, 3, 5, 7, 5, 10, 10, 1, 1, 8, 3, 3, 10, 3, 7, 9, 5, 5, 10, 10, 8, 7, 1, 8, 6, 7, 1, 6, 9, 10, 9, 10, 10, 10, 10, 9, 9, 7, 11, 3]
[ 7 7 7 7 21 1 1 3 10 13 18 19 21 22 3 10 18 1 13 13 19 19 19 21 22 15 18 18 19 19 20 22 8 13 7 7 8 12 17 18 9 10 10 19 0 2 18 22 11 3]
0.24
[10, 10, 6, 1, 5, 10, 5, 7, 9, 9, 3, 5, 3, 6, 2, 9, 8, 7, 7, 6, 7, 2, 7, 7, 1, 11, 1, 5, 9, 8, 10, 6, 6, 10, 4, 10, 10, 10, 1, 10, 8, 7, 9, 10, 10, 1, 10, 1, 8, 9]
[23 0 7 8 18 0 5 7 17 15 14 5 21 6 9 9 20 6 22 6 6 17 22 22 8 16 1 3 15 20 3 6 12 15 21 0 0 0 8 0 20 12 9 10 10 13 0 8 8 15]
0.24
[10, 10, 5, 5, 3, 10, 7, 8, 1, 1, 7, 3, 10, 11, 1, 10, 3, 3, 7, 1, 6, 7, 10, 8, 7, 6, 6, 9, 9, 7, 6, 9, 9, 10, 10, 7, 10, 10, 3, 8, 10, 8, 4, 9, 10, 3, 1, 1, 6, 2]
[19 10 18 18 19 10 12 20 1 1 6 21 23 11 1 10 19 21 22 1 12 12 0 13 22 6 6 9 2 7 22 9 9 10 10 12 10 10 14 20 23 13 4 9 10 14 20 1 6 2]
0.48
[6, 0, 3, 10, 10, 7, 9, 1, 5, 9, 10, 9, 6, 8, 5, 7, 6, 1, 1, 7, 7, 10, 3, 10, 7, 7, 7, 7, 3, 3, 3, 10, 10, 3, 5, 3, 9, 5, 3, 9, 7, 7, 8, 2, 1, 6, 7, 3, 4, 5]
[ 7 11 18 0 0 22 2 8 18 19 0 2 7 8 18 7 7 20 20 22 22 10 19 23 6 12 22 12 14 14 18 10 10 14 18 21 2 5 14 9 22 12 13 15 1 6 12 21 4 5]
0.22
[10, 9, 9, 3, 3, 4, 9, 3, 9, 9, 5, 5, 9, 9, 4, 3, 3, 11, 9, 9, 3, 9, 8, 11, 7, 3, 10, 10, 6, 6, 7, 7, 11, 3, 4, 6, 7, 7, 7, 8, 2, 5, 9, 6, 11, 3, 6, 1, 1, 11]
[10 17 19 21 21 22 9 14 19 19 5 5 2 9 4 14 14 16 2 2 21 9 20 11 6 3 23 15 6 6 12 12 18 5 4 6 6 6 12 13 17 14 2 12 11 14 7 8 8 11]
0.3
[3, 3, 7, 7, 3, 10, 9, 3, 1, 3, 10, 3, 7, 3, 9, 9, 3, 3, 3, 10, 5, 7, 2, 8, 5, 4, 7, 6, 1, 9, 10, 1, 5, 3, 9, 9, 5, 5, 7, 6, 5, 1, 1, 1, 6, 1, 7, 6, 6, 2]
[18 21 22 22 14 23 2 5 8 21 16 3 15 14 9 2 3 21 3 15 5 15 17 20 3 4 7 7 8 15 19 1 3 14 15 2 5 5 7 12 14 8 13 8 12 20 22 7 12 15]
0.22
[1, 7, 8, 2, 7, 1, 7, 7, 1, 7, 8, 1, 7, 10, 10, 1, 5, 5, 5, 3, 3, 3, 9, 10, 7, 6, 6, 6, 6, 9, 9, 1, 1, 1, 6, 7, 9, 1, 1, 1, 9, 9, 7, 1, 3, 5, 6, 10, 10, 1]
[ 1 6 13 17 22 20 22 22 8 12 13 20 22 0 0 8 18 18 18 19 19 21 2 17 6 6 7 7 12 17 2 8 1 1 6 12 2 8 20 20 9 9 12 13 14 5 6 10 10 20]
0.22
[6, 10, 7, 6, 10, 10, 7, 7, 1, 7, 10, 10, 10, 7, 6, 6, 8, 1, 1, 1, 9, 5, 5, 10, 3, 10, 6, 6, 7, 3, 6, 1, 3, 2, 1, 5, 10, 10, 10, 7, 10, 6, 6, 2, 6, 5, 3, 9, 1, 10]
[ 7 18 12 6 10 10 12 12 1 6 0 0 0 7 12 12 13 20 8 8 17 18 18 19 5 15 7 6 12 18 7 8 3 17 20 3 10 10 10 12 18 22 22 15 6 3 14 15 1 10]
0.26
[3, 7, 10, 5, 7, 10, 10, 10, 10, 9, 5, 5, 10, 8]
[18 22 0 5 7 16 0 0 0 2 18 18 19 20]
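One thing that may be worth checking in output like this: the predicted ids never exceed 11, while the true ids go up to 23, so any label outside the enumerated range can never be predicted. This is only a guess at one contributing factor, not a confirmed diagnosis. A small sketch of the check, using the last (short) batch above and assuming 12 enumerated candidates (`diagnose` is a hypothetical helper):

```python
def diagnose(pred, true, n_candidates):
    """Return (accuracy, fraction of labels outside range(n_candidates))."""
    hits = sum(p == t for p, t in zip(pred, true))
    unreachable = sum(t >= n_candidates for t in true)
    return hits / len(true), unreachable / len(true)


# Last (short) batch from the output above.
pred = [3, 7, 10, 5, 7, 10, 10, 10, 10, 9, 5, 5, 10, 8]
true = [18, 22, 0, 5, 7, 16, 0, 0, 0, 2, 18, 18, 19, 20]

# n_candidates=12 is an assumption: no predicted id above exceeds 11.
accuracy, unreachable = diagnose(pred, true, n_candidates=12)
print(f"accuracy={accuracy:.2f}, unreachable labels={unreachable:.2f}")
```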
Hello, sorry, I do not have time to check your code and results in detail. Based on my experience, directly using the scores output by RED-GNN is inferior to the model trained by a relation prediction loss. This may be attributed to the different representation spaces of the relation prediction task and entity prediction task. Hence, you can try to add a BCE loss on positive relations and negative relations during training.
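A minimal sketch of such a loss for a single (h, t) pair, assuming the model produces one logit per candidate relation. The function name and toy logits are hypothetical and not part of RED-GNN, and treating every non-true relation as a negative is an assumption. In a PyTorch training loop the same quantity would typically come from `torch.nn.functional.binary_cross_entropy_with_logits`:

```python
import math


def bce_relation_loss(logits, true_rel):
    """Mean binary cross-entropy over relation logits for one (h, t) pair.

    The true relation is the positive (label 1); every other relation is
    treated as a negative (label 0).
    """
    total = 0.0
    for r, z in enumerate(logits):
        y = 1.0 if r == true_rel else 0.0
        # Numerically stable BCE-with-logits:
        # max(z, 0) - z * y + log(1 + exp(-|z|))
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


# Toy logits (hypothetical): the loss is low when the true relation scores
# high and the others score low, and high in the opposite case.
good = bce_relation_loss([4.0, -4.0, -4.0], true_rel=0)
bad = bce_relation_loss([-4.0, 4.0, -4.0], true_rel=0)
print(good, bad)
```

Adding a term like this alongside the entity prediction loss trains the scores to be comparable across relations, which the entity prediction loss alone does not enforce.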
Thank you very much. I have solved this problem, and with your family dataset the model achieves excellent performance! Thank you very much again!
Hello! It's not clear to me whether you divided each knowledge graph into facts, a training set, a test set and a validation set. Or do the facts overlap with the other three?
Hi, please refer to this issue https://github.com/AutoML-Research/RED-GNN/issues/1. The files do not overlap.
Hello! Your work is excellent and very helpful to me! My research subject is drug-drug interaction prediction, which is a relation prediction problem. Your work achieves great performance on the (h, r, ?) task, but I wonder whether it can be turned into link prediction of the form (h, ?, t).