Closed: AndRossi closed this issue 2 years ago
Hi @AndRossi,
As mentioned in the paper,
For FB15k and WN18, we report results using basis decomposition (Eq. 3) with two basis functions, and a single encoding layer with 200-dimensional embeddings. For FB15k-237, we found block decomposition (Eq. 4) to perform best, using two layers with block dimension 5 × 5 and 500-dimensional embeddings.
for FB15k and WN18, they use basis decomposition instead of block decomposition. But the link prediction model you tried actually uses block decomposition.
So please first try changing this line to use RGCNBasisLayer. You might need to change other code correspondingly.
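For reference, basis decomposition (Eq. 3) builds each relation's weight matrix as a linear combination of a small set of shared bases, W_r = sum_b a_rb V_b, which is also what RGCNBasisLayer does internally. A minimal standalone sketch (class and variable names are mine, not the repo's):

import torch
import torch.nn as nn

class BasisWeight(nn.Module):
    """Basis decomposition (Eq. 3): W_r = sum_b a_rb * V_b."""
    def __init__(self, num_rels, num_bases, in_feat, out_feat):
        super().__init__()
        # shared basis matrices V_b and per-relation coefficients a_rb
        self.bases = nn.Parameter(torch.randn(num_bases, in_feat, out_feat))
        self.coeffs = nn.Parameter(torch.randn(num_rels, num_bases))

    def forward(self):
        num_bases, in_feat, out_feat = self.bases.shape
        # linearly combine the bases into one weight matrix per relation
        weight = torch.matmul(self.coeffs, self.bases.view(num_bases, -1))
        return weight.view(-1, in_feat, out_feat)  # (num_rels, in_feat, out_feat)

With two bases and 200-dimensional embeddings this uses far fewer parameters than a full 200 x 200 matrix per relation, which matters on FB15k with its 1345 relations.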
Thank you so much for your prompt reply!
I have changed that line, and of course I have changed the parameters used to build the RGCNLayer object from:
return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases, activation=act, self_loop=True, dropout=self.dropout)
To:
return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases, activation=act)
(I have just removed self_loop and dropout, which are not used by RGCNBasisLayer.) After doing this, the code runs.
However I encounter this error:
/home/nvidia/workspace/dbgroup/andrea/comparative_analysis/models/dgl/examples/pytorch/rgcn/utils.py:112: RuntimeWarning: divide by zero encountered in true_divide
norm = 1.0 / in_deg
This seems to lead to a KeyError later (I guess that, due to the division by zero, the norm was never computed, so the "norm" key is never set):
Traceback (most recent call last):
File "link_predict.py", line 252, in <module>
main(args)
File "link_predict.py", line 167, in main
loss = model.get_loss(g, data, labels)
File "link_predict.py", line 78, in get_loss
embedding = self.forward(g)
File "link_predict.py", line 65, in forward
return self.rgcn.forward(g)
File "/home/nvidia/workspace/dbgroup/andrea/comparative_analysis/models/dgl/examples/pytorch/rgcn/model.py", line 54, in forward
layer(g)
File "/home/nvidia/anaconda3/envs/testenv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/nvidia/workspace/dbgroup/andrea/comparative_analysis/models/dgl/examples/pytorch/rgcn/layers.py", line 39, in forward
self.propagate(g)
File "/home/nvidia/workspace/dbgroup/andrea/comparative_analysis/models/dgl/examples/pytorch/rgcn/layers.py", line 100, in propagate
g.update_all(msg_func, fn.sum(msg='msg', out='h'), None)
File "/home/nvidia/anaconda3/envs/testenv/lib/python3.6/site-packages/dgl/graph.py", line 2753, in update_all
Runtime.run(prog)
File "/home/nvidia/anaconda3/envs/testenv/lib/python3.6/site-packages/dgl/runtime/runtime.py", line 11, in run
exe.run()
File "/home/nvidia/anaconda3/envs/testenv/lib/python3.6/site-packages/dgl/runtime/ir/executor.py", line 204, in run
udf_ret = fn_data(src_data, edge_data, dst_data)
File "/home/nvidia/anaconda3/envs/testenv/lib/python3.6/site-packages/dgl/runtime/scheduler.py", line 918, in _mfunc_wrapper
return mfunc(ebatch)
File "/home/nvidia/workspace/dbgroup/andrea/comparative_analysis/models/dgl/examples/pytorch/rgcn/layers.py", line 97, in msg_func
msg = msg * edges.data['norm']
File "/home/nvidia/anaconda3/envs/testenv/lib/python3.6/site-packages/dgl/frame.py", line 612, in __getitem__
return self.select_column(key)
File "/home/nvidia/anaconda3/envs/testenv/lib/python3.6/site-packages/dgl/frame.py", line 635, in select_column
col = self._frame[name]
File "/home/nvidia/anaconda3/envs/testenv/lib/python3.6/site-packages/dgl/frame.py", line 316, in __getitem__
return self._columns[name]
KeyError: 'norm'
Any suggestions on how to fix it? Thanks again for your support, I really appreciate it.
The problem is not the numerical warning. It is mainly due to some hard-coded logic in the example code about whether the normalization factor is stored as a node feature or an edge feature.
For the entity classification example, the paper calculates the normalization factor based on edge type. Since the incoming edges of a node have different normalizers if they have different edge types, the normalization factor is a feature associated with edges. So in the message function of RGCNBasisLayer you will see edges.data['norm'] being used.
However, for the link prediction task, the normalization factor is calculated from the number of in-edges of a node, regardless of edge type. So to reduce memory usage, the normalization factor is stored as a node feature. Therefore, you will see the message function of RGCNBlockLayer access nodes.data['norm'].
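(For reference, that node-level normalizer is just the reciprocal of the in-degree, with a guard for nodes that have no in-edges. A rough sketch of the idea is below; the exact DGL calls may differ between versions, so treat it as an illustration rather than the repo's utils code.)

import numpy as np
import torch

def comp_node_norm(g):
    # reciprocal in-degree per node; isolated nodes get 0 instead of inf
    in_deg = g.in_degrees(range(g.number_of_nodes())).float().numpy()
    norm = 1.0 / in_deg
    norm[np.isinf(norm)] = 0
    return torch.from_numpy(norm).view(-1, 1)

# typically attached as a node feature before message passing, e.g.:
# g.ndata['norm'] = comp_node_norm(g)

(The division itself is what produces the "divide by zero" RuntimeWarning you saw; once the infinite entries are zeroed out it is harmless.)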
The solution is simply to change the propagate function of RGCNBasisLayer so that it reads the norm feature from the node feature dictionary. For example, you can change it to the following:
def propagate(self, g):
    if self.num_bases < self.num_rels:
        # generate all weights from bases
        weight = self.weight.view(self.num_bases,
                                  self.in_feat * self.out_feat)
        weight = torch.matmul(self.w_comp, weight).view(
            self.num_rels, self.in_feat, self.out_feat)
    else:
        weight = self.weight

    if self.is_input_layer:
        def msg_func(edges):
            # for input layer, matrix multiply can be converted to be
            # an embedding lookup using source node id
            embed = weight.view(-1, self.out_feat)
            index = edges.data['type'] * self.in_feat + edges.src['id']
            return {'msg': embed.index_select(0, index)}
    else:
        def msg_func(edges):
            w = weight.index_select(0, edges.data['type'])
            msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()
            return {'msg': msg}

    def apply_func(nodes):
        return {'h': nodes.data['h'] * nodes.data['norm']}

    g.update_all(msg_func, fn.sum(msg='msg', out='h'), apply_func)
Hi @ylfdq1118, thank you for your help. I have changed the propagate function as you suggested.
Unfortunately, I encounter another issue:
python3 link_predict.py -d FB15k --graph-batch-size 40000 --n-hidden=200 --n-bases=2 --n-layers=1 --gpu 3
Namespace(dataset='FB15k', dropout=0.2, eval_batch_size=500, evaluate_every=500, gpu=3, grad_norm=1.0, graph_batch_size=40000, graph_split_size=0.5, lr=0.01, n_bases=2, n_epochs=6000, n_hidden=200, n_layers=1, negative_sample=10, regularization=0.01)
# entities: 14951
# relations: 1345
# edges: 483142
Test graph:
# nodes: 14951, # edges: 966284
start training...
# sampled nodes: 10857
# sampled edges: 40000
/home/nvidia/workspace/dbgroup/andrea/comparative_analysis/models/dgl/examples/pytorch/rgcn/utils.py:112: RuntimeWarning: divide by zero encountered in true_divide
norm = 1.0 / in_deg
# nodes: 10857, # edges: 40000
Done edge sampling
/home/nvidia/anaconda3/envs/py37/lib/python3.7/site-packages/dgl/base.py:18: UserWarning: Initializer is not set. Use zero initializer instead. To suppress this warning, use `set_initializer` to explicitly specify which initializer to use.
warnings.warn(msg)
Epoch 0001 | Loss 0.6932 | Best MRR 0.0000 | Forward 0.2012s | Backward 0.6277s
# sampled nodes: 10925
# sampled edges: 40000
# nodes: 10925, # edges: 40000
Done edge sampling
Epoch 0002 | Loss 0.6932 | Best MRR 0.0000 | Forward 0.0877s | Backward 0.4725s
# sampled nodes: 10815
# sampled edges: 40000
# nodes: 10815, # edges: 40000
Done edge sampling
Epoch 0003 | Loss 0.6932 | Best MRR 0.0000 | Forward 0.0802s | Backward 0.4724s
# sampled nodes: 10991
# sampled edges: 40000
# nodes: 10991, # edges: 40000
Done edge sampling
Epoch 0004 | Loss 0.6932 | Best MRR 0.0000 | Forward 0.0803s | Backward 0.4722s
# sampled nodes: 10797
# sampled edges: 40000
# nodes: 10797, # edges: 40000
Done edge sampling
Epoch 0005 | Loss 0.6932 | Best MRR 0.0000 | Forward 0.0799s | Backward 0.4724s
As you can see, node and edge sampling is repeated in every epoch, and the loss does not decrease at all. Could this also depend on the new propagate function?
I tried it as well. In fact, the loss decreases very slowly: it does not start dropping until epoch 50. At epoch 100 the loss goes down to 0.3695, and at epoch 150 to 0.2664. But the MRR is still 0.
Kipf's RGCN repo (https://github.com/tkipf/relational-gcn) only implements basis decomposition for entity classification. The link prediction part is in the other author Schlichtkrull's repo (https://github.com/MichSchli/RelationPrediction). However, the link prediction example in DGL only implements the block decomposition from Schlichtkrull's code.
I think that to solve this problem, it might be necessary to check how basis decomposition for link prediction is implemented in the authors' repo (https://github.com/MichSchli/RelationPrediction).
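For context, block decomposition (Eq. 4) instead constrains each relation's weight matrix to be block-diagonal, W_r = diag(Q_r1, ..., Q_rB), which is what the RGCNBlockLayer in the DGL example implements. A rough sketch of the per-edge message it computes (naming is mine, not from either repo):

import torch
import torch.nn as nn

class BlockWeight(nn.Module):
    """Block-diagonal decomposition (Eq. 4): W_r = diag(Q_r1, ..., Q_rB)."""
    def __init__(self, num_rels, num_blocks, block_in, block_out):
        super().__init__()
        # B small blocks per relation instead of one full in_feat x out_feat matrix
        self.blocks = nn.Parameter(
            torch.randn(num_rels, num_blocks, block_in, block_out))

    def message(self, h, rel_type):
        # h: (E, B * block_in) source features, rel_type: (E,) edge type ids
        E, num_blocks = h.size(0), self.blocks.size(1)
        w = self.blocks[rel_type]              # (E, B, block_in, block_out)
        h = h.view(E, num_blocks, 1, -1)       # split the feature vector into B blocks
        return torch.matmul(h, w).view(E, -1)  # per-block matmul, then concatenate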
@AndRossi were you able to solve this problem? Or can you point me to any other repository that has link-prediction code?
@snash4 I'm sorry, I was not able to solve the problem. You can find the link-prediction code in the original repository of one of the authors of R-GCN: https://github.com/MichSchli/RelationPrediction
However I was not able to replicate the original results with their code either.
@AndRossi thanks for the response. The original code is too slow, it takes a few days to complete, and I too wasn't able to match the results, or maybe I didn't try hard enough because of the time it takes.
Reopened. It seems that the reproducibility is still an issue.
This issue has been automatically marked as stale due to lack of activity. It will be closed if no further activity occurs. Thank you
Hi, I'm trying to use your implementation of R-GCN to replicate the R-GCN paper results on dataset FB15K.
I have launched the training with
(I have found the values for the hyperparameters in the original paper).
Unfortunately, after 6000 epochs I get the following results:
MRR (raw): 0.148426
Hits (raw) @ 1: 0.026888
Hits (raw) @ 3: 0.150023
Hits (raw) @ 10: 0.471340
These values seem far from those reported in the original paper; in particular, their MRR (raw) is 0.251. (As for the hits, the original paper only reports them in the filtered setting, so I cannot make a precise comparison; however, the hits in the paper seem much better as well.)
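For reference, my understanding of the raw vs. filtered protocol is roughly the following (a sketch of the metric computation, not the evaluation code from this repo):

import numpy as np

def rank_of_target(scores, target, known_true=None):
    # scores: score of every candidate entity for one test triple
    # raw ranking keeps all candidates; filtered ranking masks out the other
    # candidates that are known to form true triples
    scores = np.asarray(scores, dtype=float).copy()
    if known_true is not None:
        for e in known_true:
            if e != target:
                scores[e] = -np.inf
    return int((scores > scores[target]).sum()) + 1

def mrr_and_hits(ranks, ks=(1, 3, 10)):
    ranks = np.asarray(ranks, dtype=float)
    mrr = float((1.0 / ranks).mean())
    hits = {k: float((ranks <= k).mean()) for k in ks}
    return mrr, hits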
Am I doing something wrong?