pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

`GNNExplainer` for link prediction #1728

Closed · chunyuma closed this 3 years ago

chunyuma commented 4 years ago

❓ Questions & Help

Hi @rusty1s,

I wrote a GNN model with PyTorch Geometric to do link prediction, following this example. The forward part of the model is shown below. I tried to use GNNExplainer to explain my model, but got the error forward() got an unexpected keyword argument 'edge_index'. Since my forward doesn't take edge_index directly, this is probably the cause. Could you provide an example of using GNNExplainer for link prediction, or some instructions on how to modify GNNExplainer for my model? Thank you!

    def forward(self, x, adjs, link, n_id):
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout_p, training=self.training)

        # Look up the local position of each global node ID in `n_id`,
        # then combine source and target embeddings element-wise.
        src = [int(np.where(n_id.numpy() == i.numpy())[0][0]) for i in link[:, 0]]
        dst = [int(np.where(n_id.numpy() == i.numpy())[0][0]) for i in link[:, 1]]
        x = x[src] * x[dst]
        x = torch.sigmoid(self.lin(x)).squeeze(1)

        return x
rusty1s commented 4 years ago

I think GNNExplainer can be modified for the task of link prediction. You may want to add a new explain_link function to achieve that. Note that at the moment, GNNExplainer is restricted to full-batch subgraphs and cannot handle multiple adjacencies.

chunyuma commented 3 years ago

Thanks @rusty1s. But I'm a little confused about the edge_index here. GNNExplainer basically learns an edge_mask for a given edge_index. In my model, however, each edge_index is one of the bipartite graphs generated by NeighborSampler. Does this mean GNNExplainer needs to learn three edge_masks, one for the edge_index of each bipartite graph (I have three layers of SAGEConv)? I looked at the __set_masks__ function in the GNNExplainer source code. There is a piece of code that sets the mask for each MessagePassing module:

        # Flag every MessagePassing module so that its messages are
        # multiplied by the (single, shared) learnable edge mask:
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = True
                module.__edge_mask__ = self.edge_mask

So I'm curious whether I need to modify this part to assign a different edge_mask to each bipartite graph. I would also need to add the three edge_masks to optimizer = torch.optim.Adam([self.edge_mask], lr=self.lr), right?

rusty1s commented 3 years ago

In general, GNNExplainer only learns a single edge mask, so it is not trivial to apply in a NeighborSampler scenario. Since GNNExplainer only operates on the L-hop neighborhood around a single node, I do not see the need for neighbor sampling here. As an alternative, I suggest defining a model forward function that operates on a single edge_index, and using that one for explaining nodes/links, e.g.:

def forward(self, x, edge_index):
    ...

def forward_with_sampling(self, x, adjs, link, n_id):
    ...
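
For instance, a full-graph forward for the three-layer model from the question might look like this (it reuses self.convs, self.num_layers and self.dropout_p from the snippet above, so treat it as an assumption about the model rather than tested code):

def forward(self, x, edge_index):
    # Every SAGEConv layer sees the same full-graph edge_index, which is
    # exactly what GNNExplainer expects when it attaches its edge mask.
    for i, conv in enumerate(self.convs):
        x = conv(x, edge_index)
        if i != self.num_layers - 1:
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_p, training=self.training)
    return x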
chunyuma commented 3 years ago

I see. Thanks @rusty1s. However, since my model has to use the GraphSAGE message-passing scheme (I used the SAGEConv you developed here), which updates each target node via consecutive convolutions over its K-hop neighborhood, the NeighborSampler is needed for link prediction, based on the example you provided. Do you have any suggestions for this situation?

Although GNNExplainer normally learns a single edge_mask, based on my understanding the same logic can be applied to multiple edge_masks, each operating on only a 1-hop neighborhood. In other words, if module.__edge_mask__ = self.edge_mask sets a single edge_mask for a module, I can set three edge_masks for the three SAGEConv layers and optimize all three with torch.optim.Adam([edge_mask1, edge_mask2, edge_mask3], lr=self.lr). A rough sketch of what I mean is below. Do you think it can work? Thank you!
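
Sketch (assuming model and the three adjs from NeighborSampler are in scope; untested):

import torch
from torch_geometric.nn import MessagePassing

# One learnable mask per bipartite edge_index from NeighborSampler.
edge_masks = [torch.nn.Parameter(torch.randn(adj.edge_index.size(1)) * 0.1)
              for adj in adjs]

# Attach each mask to its own SAGEConv layer instead of sharing one mask.
conv_layers = [m for m in model.modules() if isinstance(m, MessagePassing)]
for module, mask in zip(conv_layers, edge_masks):
    module.__explain__ = True
    module.__edge_mask__ = mask

# Optimize all three masks jointly while the model weights stay frozen.
optimizer = torch.optim.Adam(edge_masks, lr=0.01)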

rusty1s commented 3 years ago

Sure, that can work. However, note that during inference, GraphSAGE operates on the full graph (NeighborSampler with sizes=[-1]), meaning that you can use a single edge_mask for all consecutive layers, e.g.:
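
This is the subgraph_loader pattern from the GraphSAGE example (a sketch, with data being your full Data object and an arbitrary batch size):

from torch_geometric.data import NeighborSampler

# sizes=[-1] keeps every neighbor, so each layer sees the complete 1-hop
# neighborhood and a single shared edge_mask is valid across all layers.
subgraph_loader = NeighborSampler(data.edge_index, node_idx=None,
                                  sizes=[-1], batch_size=1024,
                                  shuffle=False)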

chunyuma commented 3 years ago

I see, thanks @rusty1s.

chunyuma commented 3 years ago

However, note that during inference, GraphSAGE operates on the full graph with NeighborSampler size =-1, meaning that you can use a single edge_mask for consecutive layers.

Hi @rusty1s, regarding your statement above: am I right that both the forward and inference functions used within the explain_node function of gnn_explainer will not train the parameters of any of the layers, but only train the edge_mask, because we set self.model.eval() within explain_node and set module.__explain__ = True and module.__edge_mask__ = self.edge_mask?

rusty1s commented 3 years ago

It will not train any model parameters, since those parameters are not passed to the optimizer, see here.
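
For reference, the optimizer inside explain_node receives only the masks (sketch; exact variable names may differ in the current source):

# Model weights stay frozen via self.model.eval(); only the masks learn.
optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
                             lr=self.lr)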

chunyuma commented 3 years ago

Hi @rusty1s, I want to add the Laplacian loss that the GNNExplainer author used in his code here to the loss function. But it seems it needs the masked_adj matrix so that gradient descent can adjust the edge_mask. Is it possible to modify your source code so that I can get the masked_adj matrix? Thank you!

rusty1s commented 3 years ago

Yes sure, you can do that. You can get a dense version via:

masked_adj = torch.zeros(N, N)
masked_adj[edge_index[0], edge_index[1]] = self.edge_mask.sigmoid()
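
With this dense matrix, the Laplacian term from the original implementation can be added along these lines (a sketch: pred_label denotes the per-node predicted labels used in the author's code, and the weighting coefficient is up to you):

# Graph Laplacian of the masked adjacency: L = D - A.
D = torch.diag(masked_adj.sum(dim=1))
L = D - masked_adj
# Quadratic smoothness penalty y^T L y, which is small when strongly-kept
# edges connect nodes with similar predicted labels.
pred_label = pred_label.float()
lap_loss = (pred_label @ L @ pred_label) / masked_adj.numel()
loss = loss + lap_loss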
apratim-mishra commented 2 years ago

Hi @chunyuma, were you able to get this to work? GNNExplainer for link prediction.

rusty1s commented 2 years ago

You can have a look at https://github.com/pyg-team/pytorch_geometric/discussions/4058#discussioncomment-2157357 for this :)

apratim-mishra commented 2 years ago

Hi @rusty1s, thank you. I was able to get the to_captum approach to work, based on integrated gradients. What I am asking for is a working GNNExplainer approach for link prediction, based on masking the two endpoints of a link.

Also, the 2 examples at : https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gnn_explainer.py https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gnn_explainer_ba_shapes.py

don't seem to work now.

rusty1s commented 2 years ago

Thanks for reporting. This is now fixed: https://github.com/pyg-team/pytorch_geometric/commit/7526c8b1a508802d9c593b350d1ef2edf18cbbdb

apratim-mishra commented 2 years ago

Hi @rusty1s ,

I am trying to understand gnnexplainer for the link prediction task.

https://cs.stanford.edu/people/jure/pubs/gnnexplainer-neurips19.pdf simply mentions: 'When predicting a link (vj, vk), GNNEXPLAINER learns two masks XS(vj) and XS(vk) for both endpoints of the link.'

So, for the https://github.com/pyg-team/pytorch_geometric/blob/master/examples/link_pred.py example, would feature masks from the two endpoint nodes be a valid explanation for the link between them?

rusty1s commented 2 years ago

Interesting. In order to let GNNExplainer support link-level tasks, we would need to provide an explain_link method. It is currently a bit tricky to implement, since it requires learning two separate edge masks for the source and destination node. We could achieve this by duplicating the model and setting an individual edge mask in each copy, roughly as sketched below.
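
A rough sketch of that idea (set_masks is a hypothetical helper that attaches a given mask to every MessagePassing module, in the spirit of __set_masks__ above; the loss follows the usual GNNExplainer objective of maximizing the prediction plus a sparsity penalty):

import copy
import torch

def explain_link(model, x, edge_index, src, dst, epochs=200, lr=0.01):
    model.eval()
    num_edges = edge_index.size(1)
    src_mask = torch.nn.Parameter(torch.randn(num_edges) * 0.1)
    dst_mask = torch.nn.Parameter(torch.randn(num_edges) * 0.1)

    # Two copies of the model, each carrying its own learnable edge mask.
    model_src, model_dst = model, copy.deepcopy(model)
    set_masks(model_src, src_mask)  # hypothetical helper, cf. __set_masks__
    set_masks(model_dst, dst_mask)

    optimizer = torch.optim.Adam([src_mask, dst_mask], lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        z_src = model_src(x, edge_index)[src]
        z_dst = model_dst(x, edge_index)[dst]
        prob = torch.sigmoid((z_src * z_dst).sum())
        # Maximize the link probability; penalize large masks for sparsity.
        loss = -torch.log(prob) + 0.005 * (src_mask.sigmoid().sum()
                                           + dst_mask.sigmoid().sum())
        loss.backward()
        optimizer.step()

    return src_mask.sigmoid().detach(), dst_mask.sigmoid().detach()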

I don't think link-level support for GNNExplainer is a top priority, TBH. Pinging @RBendias just in case you have additional thoughts or resources.

Alec-Stashevsky commented 2 years ago

I would also like to put in a vote for adding link-prediction functionality to the GNNExplainer class!

rusty1s commented 2 years ago

Noted :) Let's see what we can do!