chunyuma closed this issue 3 years ago
I think that the `GNNExplainer` can be modified for the task of link prediction. You may want to add a new function `explain_link` to achieve that. Note that at the moment, this is restricted to full-batch subgraphs, and cannot handle multiple adjacencies.
Thanks @rusty1s. But here I'm a little bit confused about the `edge_index`. `GNNExplainer` basically learns the `edge_mask` for `edge_index`. However, in my model, the `edge_index` is the bipartite edges generated by `NeighborSampler`. So does this mean `GNNExplainer` needs to learn three `edge_mask`s for the three `edge_index`s of the three bipartite graphs (I have three layers of `SAGEConv`)? I looked at the `__set_masks__` function in the `GNNExplainer` source code. There is a piece of code that sets the mask for each `MessagePassing` module:
```python
for module in self.model.modules():
    if isinstance(module, MessagePassing):
        module.__explain__ = True
        module.__edge_mask__ = self.edge_mask
```
So I'm curious whether I need to modify this part to assign a different `edge_mask` to each bipartite graph. And I also need to add the three `edge_mask`s in `optimizer = torch.optim.Adam([self.edge_mask], lr=self.lr)`, right?
In general, `GNNExplainer` only learns a single edge mask, and thus `GNNExplainer` is not trivial to apply in a `NeighborSampler` scenario. Since `GNNExplainer` only operates on the L-hop neighborhood around a single node, I do not see the need to apply neighbor sampling here. As an alternative, I suggest defining a model `forward` function that operates on a single `edge_index`, and using that one for explaining nodes/links, e.g.:
```python
def forward(self, x, edge_index):
    ...

def forward_with_sampling(self, x, adjs, link, n_id):
    ...
```
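This split can be sketched without PyG (a toy mean-aggregation stand-in for `SAGEConv`, numpy only; the class, method names, and graph here are all hypothetical). The point is that the full-batch path reuses one `edge_index` for every layer, while the sampled path consumes one (bipartite) `edge_index` per layer:

```python
import numpy as np

class ToySAGE:
    """Hypothetical 2-layer stand-in for a GraphSAGE model (numpy only)."""

    def __init__(self, num_layers=2):
        self.num_layers = num_layers

    def conv(self, x, edge_index):
        # Mean aggregation over incoming edges plus a self term;
        # a placeholder for SAGEConv, not its real arithmetic.
        n = x.shape[0]
        agg = np.zeros_like(x)
        deg = np.zeros(n)
        for src, dst in zip(edge_index[0], edge_index[1]):
            agg[dst] += x[src]
            deg[dst] += 1
        deg[deg == 0] = 1
        return x + agg / deg[:, None]

    def forward(self, x, edge_index):
        # Full-batch path: a single edge_index reused by every layer
        # (this is the signature the explainer can call directly).
        for _ in range(self.num_layers):
            x = self.conv(x, edge_index)
        return x

    def forward_with_sampling(self, x, adjs):
        # Sampled path: one edge_index per layer, as produced by NeighborSampler.
        for edge_index in adjs:
            x = self.conv(x, edge_index)
        return x
```

When the sampler keeps full neighborhoods, passing the same `edge_index` for every layer makes both paths agree, which is why a single-`edge_index` `forward` suffices for explanation.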
I see. Thanks @rusty1s. However, since my model has to use the `GraphSAGE` message-passing scheme for link prediction (I used the `SAGEConv` you developed here), which updates the target node through K consecutive convolutions over its K-hop neighborhood, the `NeighborSampler` is needed based on the example you provided. Do you have any suggestions for handling this situation?
Although `GNNExplainer` normally learns a single `edge_mask`, based on my understanding the same logic can be applied to multiple `edge_mask`s, each operating only on a 1-hop neighborhood. In other words, if `module.__edge_mask__ = self.edge_mask` is used to set a single `edge_mask` for a module, I can set three `edge_mask`s for the three `SAGEConv` layers. And in `torch.optim.Adam([edge_mask1, edge_mask2, edge_mask3], lr=self.lr)`, I try to optimize the three `edge_mask`s. Do you think this can work? Thank you!
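The per-layer masking idea can be sketched in numpy (a toy mean-aggregation stand-in for `SAGEConv`; the graph, values, and function name are hypothetical). As in `GNNExplainer`, a sigmoid turns each mask logit into a soft keep/drop weight per edge, and each layer gets its own mask array:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def masked_mean_conv(x, edge_index, edge_mask):
    """Toy SAGEConv stand-in: mean of sigmoid-masked neighbour messages + self term."""
    n = x.shape[0]
    agg = np.zeros_like(x)
    deg = np.zeros(n)
    w = sigmoid(edge_mask)  # logits -> soft edge weights in (0, 1)
    for src, dst, m in zip(edge_index[0], edge_index[1], w):
        agg[dst] += m * x[src]
        deg[dst] += 1
    deg[deg == 0] = 1
    return x + agg / deg[:, None]

# Toy path graph 0-1-2 with both edge directions.
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
x = np.array([[1.0], [2.0], [3.0]])

# One independent mask per layer, as proposed; in the real setting these
# arrays would be the (only) parameters handed to torch.optim.Adam.
edge_mask1 = np.full(4, 10.0)                      # logit 10 -> weight ~1, all edges kept
edge_mask2 = np.array([10.0, 10.0, -10.0, -10.0])  # second layer drops edges touching node 2

h = masked_mean_conv(x, edge_index, edge_mask1)
h = masked_mean_conv(h, edge_index, edge_mask2)
```

Dropping an edge in only one layer's mask changes only that hop of the computation, which is exactly the per-layer granularity the three-mask setup buys.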
Sure, that can work. However, note that during inference, `GraphSAGE` operates on the full graph with `NeighborSampler` `sizes=[-1]`, meaning that you can use a single `edge_mask` for consecutive layers.
I see, thanks @rusty1s.
> However, note that during inference, `GraphSAGE` operates on the full graph with `NeighborSampler` `sizes=[-1]`, meaning that you can use a single `edge_mask` for consecutive layers.
Hi @rusty1s, regarding your statement above, I'm wondering whether the `forward` and `inference` functions used within the `explain_node` function of `GNNExplainer` will leave the parameters of all the model's layers untrained and only train the `edge_mask`, because we call `self.model.eval()` within `explain_node` and set `module.__explain__ = True` and `module.__edge_mask__ = self.edge_mask`.
It will not train any model parameters since those parameters are not passed to the optimizer, see here.
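That mechanism can be illustrated with a tiny hand-rolled gradient descent (pure numpy, all values hypothetical): the "model weight" stays fixed because only the mask logit receives updates, mirroring an optimizer that was given only `edge_mask`:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

w = 2.0                      # frozen "model" weight -- never passed to the optimizer
x_src = 1.5                  # message arriving along one edge
edge_mask = np.array([0.0])  # the only trainable parameter (a mask logit)
target = 0.0                 # drive the masked prediction toward 0

lr = 1.0
for _ in range(200):
    m = sigmoid(edge_mask)
    pred = w * m * x_src
    # d/d(edge_mask) of (pred - target)^2, via the chain rule through sigmoid
    grad = 2 * (pred - target) * w * x_src * m * (1 - m)
    edge_mask -= lr * grad   # only the mask moves; w is untouched
```

After training, the mask has (softly) switched the edge off while the model weight is exactly what it started as.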
Hi @rusty1s, I want to add the Laplacian loss that the GNNExplainer author used in his code here to the loss function. But it seems it needs the `masked_adj` matrix so that the model can use gradient descent to adjust the `edge_mask`. Is it possible for me to modify your source code so that I can get the `masked_adj` matrix? Thank you!
Yes, sure, you can do that. You can get a dense version via:

```python
masked_adj = torch.zeros(N, N)
masked_adj[edge_index[0], edge_index[1]] = self.edge_mask.sigmoid()
```
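For illustration, the same dense construction and an unnormalized Laplacian built from it can be sketched in numpy (toy graph and mask values are made up; the exact regularizer in the author's code may differ from this L = D - A form):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical toy graph: 4 nodes, 4 directed edges in COO layout (like edge_index).
edge_index = np.array([[0, 1, 2, 3],
                       [1, 2, 3, 0]])
edge_mask = np.array([4.0, 0.0, -4.0, 2.0])  # learned mask logits

N = 4
masked_adj = np.zeros((N, N))
masked_adj[edge_index[0], edge_index[1]] = sigmoid(edge_mask)

# Unnormalized graph Laplacian of the (symmetrized) masked adjacency,
# the quantity a Laplacian-style regularizer would consume.
A = (masked_adj + masked_adj.T) / 2
D = np.diag(A.sum(axis=1))
L = D - A
```

Because `masked_adj` is a differentiable function of `edge_mask` (through the sigmoid), any loss built on `L` back-propagates into the mask.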
Hi @chunyuma, were you able to get this to work? (`GNNExplainer` for link prediction)
You can have a look at https://github.com/pyg-team/pytorch_geometric/discussions/4058#discussioncomment-2157357 for this :)
Hi @rusty1s, thank you, I was able to get the `to_captum` approach to work, based on integrated gradients. I am asking for a working approach for `GNNExplainer` on link prediction, based on masking the 2 endpoints of a link.
Also, the 2 examples at https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gnn_explainer.py and https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gnn_explainer_ba_shapes.py don't seem to work now.
Thanks for reporting. This is now fixed: https://github.com/pyg-team/pytorch_geometric/commit/7526c8b1a508802d9c593b350d1ef2edf18cbbdb
Hi @rusty1s ,
I am trying to understand gnnexplainer for the link prediction task.
https://cs.stanford.edu/people/jure/pubs/gnnexplainer-neurips19.pdf simply mentions that: 'When predicting a link (vj, vk), GNNEXPLAINER learns two masks XS(vj) and XS(vk) for both endpoints of the link'.
So, for the https://github.com/pyg-team/pytorch_geometric/blob/master/examples/link_pred.py example, would feature masks from the 2 endpoint nodes be a valid explanation for the link between them?
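In that spirit, the two-endpoint feature masking can be sketched in numpy (all values hypothetical; in the real setting the two mask logit vectors would be optimized rather than set by hand, and the embeddings would come from the encoder):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def link_score(z_j, z_k, mask_j, mask_k):
    """Dot-product decoder (link_pred.py-style) applied to masked endpoint features."""
    return float((sigmoid(mask_j) * z_j) @ (sigmoid(mask_k) * z_k))

# Hypothetical endpoint embeddings for the link (v_j, v_k).
z_j = np.array([1.0, -2.0, 0.5])
z_k = np.array([3.0, 1.0, -1.0])

# One learnable feature-mask logit vector per endpoint, as in the paper.
open_mask = np.full(3, 20.0)             # logit 20 -> weight ~1, all features kept
drop_f1 = np.array([20.0, -20.0, 20.0])  # feature 1 masked out on endpoint j
```

Comparing the score with and without a feature switched off shows how much that endpoint feature contributes to the predicted link.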
Interesting. In order to let `GNNExplainer` support link-level tasks, we would need to provide the `explain_link` method. It is currently a bit tricky to implement, since it requires us to learn two separate edge masks for the source and destination nodes. We could achieve this by duplicating the model and setting individual edge masks in each copy.
I don't think link-level support for `GNNExplainer` is a top prio TBH. Pinging @RBendias just in case you have additional thoughts or resources.
I would also like to put in a vote for adding link-prediction functionality to the GNNExplainer class!
Noted :) Let's see what we can do!
❓ Questions & Help
Hi @rusty1s,
I wrote a GNN model with PyTorch Geometric to do link prediction, referring to this example. The forward part of the model is shown below. I tried to use GNNExplainer to explain my model but got an error: `forward() got an unexpected keyword argument 'edge_index'`. Since my `forward` doesn't directly use `edge_index`, this might cause the issue. Is it possible to get an example of using GNNExplainer for link prediction? Or some instructions on modifying GNNExplainer for my model? Thank you!
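One common workaround for this kind of signature mismatch (a hedged sketch, not PyG API; every name below is hypothetical) is an adapter whose `forward` takes `x` and `edge_index` directly, so the explainer can call it, and which delegates to the original link model for one fixed link:

```python
import numpy as np

class DummyLinkModel:
    """Stand-in for a GraphSAGE-style link-prediction model (hypothetical)."""

    def encode(self, x, edge_index):
        return x  # placeholder for the SAGEConv stack

    def decode(self, z, link):
        j, k = link
        return float(z[j] @ z[k])  # dot-product link score

class ExplainableLinkModel:
    """Adapter exposing forward(x, edge_index) -- the signature the
    explainer calls -- for one fixed link (j, k)."""

    def __init__(self, model, link):
        self.model = model
        self.link = link

    def forward(self, x, edge_index):
        z = self.model.encode(x, edge_index)
        return self.model.decode(z, self.link)
```

The explainer then sees a model that consumes `edge_index` like a node-level model, while the link of interest is baked into the wrapper.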