divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

GNNExplainer: Edge mask shape for graph classification #112

Closed. GiuseppeSerra93 closed this issue 2 years ago.

GiuseppeSerra93 commented 2 years ago

Hello, I am trying to use GNNExplainer on BA_2Motifs with GIN_3l as the base model. Consider the following code:

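from dig.xgraph.method import GNNExplainer

# model, test_loader and num_classes are defined earlier (GIN_3l trained on BA_2Motifs).
explainer = GNNExplainer(model, epochs=100, lr=0.01, explain_graph=True)
sparsity = 0.5
for i, data in enumerate(test_loader):
    print(f'Explain graph {i}')
    edge_masks, hard_edge_masks, related_preds = explainer(
        data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes)
    print(data.x.shape)           # node features: [num_nodes, num_features]
    print(data.edge_index.shape)  # edges: [2, num_edges]
    print(edge_masks[0].shape)
    print(hard_edge_masks[0].shape)
    print()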

A snippet of the output is:

Explain graph 0
torch.Size([25, 10])
torch.Size([2, 50])
torch.Size([75])
torch.Size([75])

Explain graph 1
torch.Size([25, 10])
torch.Size([2, 52])
torch.Size([77])
torch.Size([77])

At first, I expected the mask length to equal data.edge_index.shape[1], since this should be an edge-level explanation. Instead, the mask length appears to be data.edge_index.shape[1] + data.x.shape[0] (50 + 25 = 75 and 52 + 25 = 77 above). Is that how it is supposed to work?

Also, since the graph is undirected, does the mask remove an edge in both directions, e.g. (0,1) and (1,0)? I also tried slicing the mask with edge_mask[data.x.shape[0]:] or edge_mask[:-data.x.shape[0]], but when I applied these masks to edge_index the results did not look convincing. Am I missing something or doing something wrong? Many thanks for any clarification.
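One way to answer the direction question empirically is to check, for every directed edge whose reverse is also stored, whether the two directions received the same mask value. Below is a minimal sketch in plain torch, assuming the mask has already been trimmed to exactly one entry per column of edge_index; asymmetric_edges is a hypothetical helper, not part of DIG.

```python
from __future__ import annotations

import torch


def asymmetric_edges(edge_index: torch.Tensor, mask: torch.Tensor) -> list[tuple[int, int]]:
    """List directed edges (u, v) whose reverse (v, u) is stored but has a different mask value.

    Assumes mask has one entry per column of edge_index (self-loop entries removed).
    """
    if mask.numel() != edge_index.shape[1]:
        raise ValueError("mask must have exactly one entry per edge")
    # Map each directed edge to whether the mask keeps it.
    value = {(u, v): bool(m) for (u, v), m in zip(edge_index.t().tolist(), mask.tolist())}
    # Keep only edges whose reverse exists and disagrees with them.
    return [(u, v) for (u, v), m in value.items()
            if (v, u) in value and value[(v, u)] != m]


# Toy example: (0, 1) is kept but (1, 0) is not; (1, 2) and (2, 1) are both kept.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
hard_mask = torch.tensor([1., 0., 1., 1.])
print(asymmetric_edges(edge_index, hard_mask))  # [(0, 1), (1, 0)]
```

An empty result would mean the mask treats both directions of every stored pair alike. Duplicate (u, v) columns are collapsed by the dictionary, which is fine as long as edge_index has no repeated edges.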

alirezadizaji commented 2 years ago

I think this happens because self-loop edges are added, one per node, so the number of self-loop edges equals the number of nodes.
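A minimal sketch of the effect described above, assuming one self-loop per node is appended after the original edges (the layout that PyG's torch_geometric.utils.add_self_loops produces). append_self_loops is a hypothetical stand-in written in plain torch so the example depends on neither DIG nor PyG; the toy graph is made up for illustration.

```python
import torch


def append_self_loops(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Append one (i, i) edge per node after the existing edges.

    Original edges come first, self-loops last, in node order 0..num_nodes-1.
    """
    loops = torch.arange(num_nodes, dtype=edge_index.dtype).repeat(2, 1)
    return torch.cat([edge_index, loops], dim=1)


# Toy undirected graph: 4 nodes, 3 edges stored in both directions (6 columns).
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
num_nodes = 4

with_loops = append_self_loops(edge_index, num_nodes)
# A mask over the augmented graph has one entry per column: E + N.
print(edge_index.shape[1], with_loops.shape[1])  # 6 10
```

Under this assumption, the shapes in the issue give 50 + 25 = 75 and 52 + 25 = 77, matching the reported mask lengths. It also means the self-loop entries sit at the end of the mask, so edge_mask[:-N] (not edge_mask[N:]) lines up with edge_index.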

GiuseppeSerra93 commented 2 years ago

Thanks, understood. The self-loop edges are appended at the end of the tensors by the PyG utility functions. So, if I want to extract only the selected edges, I can do something like:

data.edge_index[:, hard_edge_masks[0][:-data.x.shape[0]].bool()]

The result looks like:

tensor([[ 0,  0,  0,  0,  0,  1,  2,  2,  4,  5,  5,  7,  8,  9, 10, 11, 12, 13,
         15, 16, 18, 20, 20, 21, 21, 22, 22, 23],
        [ 1,  4, 11, 13, 20, 14,  7, 15,  0,  2,  8,  2,  5, 19, 16,  0,  5,  0,
          2, 10,  2, 21, 24, 20, 24, 21, 23, 20]])
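The same selection can be wrapped in a small helper that first checks the length assumption, so a layout mismatch fails loudly instead of silently selecting the wrong edges. This is a sketch only: selected_edges is hypothetical and assumes the N self-loop entries are the last N entries of the mask.

```python
from __future__ import annotations

import torch


def selected_edges(edge_index: torch.Tensor, hard_mask: torch.Tensor,
                   num_nodes: int) -> torch.Tensor:
    """Return the original edges whose hard-mask entry is nonzero.

    Assumes hard_mask has length E + num_nodes, with the self-loop
    entries at the end, so the first E entries align with edge_index.
    """
    num_edges = edge_index.shape[1]
    if hard_mask.numel() != num_edges + num_nodes:
        raise ValueError("mask length is not E + N; the self-loop layout assumption fails")
    edge_part = hard_mask[:num_edges].bool()  # drop the trailing self-loop entries
    return edge_index[:, edge_part]


# Toy example: 3 nodes, 4 directed edges, plus 3 trailing self-loop entries.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
hard_mask = torch.tensor([1., 0., 1., 1., 1., 1., 1.])
print(selected_edges(edge_index, hard_mask, num_nodes=3).tolist())  # [[0, 1, 2], [1, 2, 1]]
```

Once the length check passes, slicing with [:num_edges] is equivalent to [:-num_nodes], and it does not rely on num_nodes being nonzero ([:-0] is an empty slice in Python).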

This brings me back to my second question. Most edges appear to be selected in only one direction, and only some in both. Does this matter if I want to plot which edges of a given graph are selected and which are not? Can I simply highlight every edge in the output above? Thanks again; I just want to double-check that this is correct.
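If the plot draws the graph as undirected, one option is to collapse the two stored directions of each edge into a single undirected edge before highlighting. Whether an edge should count as selected when either direction is kept, or only when both are, is a modeling choice; nothing in this thread says which one DIG intends. undirected_selection is a hypothetical helper that supports both policies and assumes the self-loop entries have already been removed from the mask.

```python
from __future__ import annotations

import torch


def undirected_selection(edge_index: torch.Tensor, mask: torch.Tensor,
                         require_both: bool = False) -> set[tuple[int, int]]:
    """Collapse a per-direction edge mask into a set of undirected edges.

    Each undirected edge is keyed as (min(u, v), max(u, v)). With
    require_both=False an edge is kept if any stored direction is selected;
    with require_both=True, only if every stored direction is selected.
    """
    if mask.numel() != edge_index.shape[1]:
        raise ValueError("mask must have exactly one entry per edge")
    votes: dict[tuple[int, int], list[bool]] = {}
    for (u, v), keep in zip(edge_index.t().tolist(), mask.bool().tolist()):
        votes.setdefault((min(u, v), max(u, v)), []).append(keep)
    combine = all if require_both else any
    return {key for key, flags in votes.items() if combine(flags)}


# Toy example: (0, 1) is kept in one direction only, (1, 2) in both.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
hard_mask = torch.tensor([1., 0., 1., 1.])
print(sorted(undirected_selection(edge_index, hard_mask)))                     # [(0, 1), (1, 2)]
print(sorted(undirected_selection(edge_index, hard_mask, require_both=True)))  # [(1, 2)]
```

With require_both=False, every edge in the output above would be highlighted, and pairs such as (0, 4) and (4, 0) would be drawn as a single line. With require_both=True, only pairs that appear in both directions, such as (0, 4), would remain.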