divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

Node Classification Explainer Raises Error: DeepLIFT, GNNExplainer #129

Closed: matekenya closed this issue 2 years ago

matekenya commented 2 years ago

I am attempting to use various explainers to explain a node classification output.

from dig.xgraph.method import DeepLIFT, GNN_LRP, GNNExplainer, GradCAM
explainer = GNNExplainer(model, explain_graph=False)
explainer(data.x, data.edge_index, node_idx=node_indices[20], sparcity=0.5)

According to the documentation, the explainer requires x, edge_index, and node_idx (an int):

https://diveintographs.readthedocs.io/en/latest/_modules/dig/xgraph/method/gnnexplainer.html#GNNExplainer
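Since the documentation lists node_idx as an int, one thing worth checking is the type actually being passed. Assuming node_indices is a tensor (its type is not shown in the issue), node_indices[20] is a 0-dimensional tensor rather than a plain Python int. A minimal, library-free sketch of normalizing such a value; the helper to_node_index and the FakeScalarTensor stand-in are hypothetical and not part of DIG:

```python
from typing import Any


def to_node_index(value: Any) -> int:
    """Convert an index-like value to a plain Python int.

    Hypothetical helper (not part of DIG). Accepts a Python int or any
    scalar object exposing .item(), such as a 0-dim torch.Tensor or a
    NumPy integer, and returns a plain int.
    """
    # bool is a subclass of int, but True/False are not meaningful node ids.
    if isinstance(value, bool):
        raise TypeError("node_idx must be an integer, not bool")
    if isinstance(value, int):
        return value
    item = getattr(value, "item", None)
    if callable(item):
        scalar = item()  # e.g. a 0-dim integer tensor -> Python int
        if isinstance(scalar, int) and not isinstance(scalar, bool):
            return scalar
    raise TypeError(f"cannot use {value!r} as node_idx; expected an integer")


class FakeScalarTensor:
    """Stand-in for a 0-dim tensor so this sketch runs without PyTorch."""

    def __init__(self, v: int) -> None:
        self._v = v

    def item(self) -> int:
        return self._v


if __name__ == "__main__":
    print(to_node_index(13))                    # 13
    print(to_node_index(FakeScalarTensor(10)))  # 10
```

With PyTorch installed, int(node_indices[20]) or node_indices[20].item() performs the same conversion for a single-element integer tensor.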

However, when I pass a variable as node_idx, it raises the error below:

[Screenshot of the error message]
CM-BF commented 2 years ago

Thank you for reporting this! I've fixed the bug. By the way, there are some errors in your code; you could use

edge_masks, hard_edge_masks, related_preds = explainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes, node_idx=node_idx)

instead of

explainer(data.x, data.edge_index, node_idx=13, sparcity=0.5)  # sparcity -> sparsity

Please refer to the example for more details.
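The sparcity typo is easy to miss because a misspelled keyword argument does not necessarily raise an error. Assuming the explainer reads optional settings such as sparsity from **kwargs (an assumption about DIG's internals, not confirmed in this thread), a misspelled key is silently ignored and a default is used instead. A self-contained illustration; explain(), explain_strict(), KNOWN_OPTIONS, and the default of 0.7 are all hypothetical:

```python
from typing import Any, Dict

# Hypothetical set of option names the explainer understands.
KNOWN_OPTIONS = {"sparsity", "num_classes", "node_idx"}


def explain(x: Any, edge_index: Any, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical explainer entry point that reads options from **kwargs."""
    # A misspelled key such as 'sparcity' never matches, so the default is used.
    sparsity = kwargs.get("sparsity", 0.7)
    node_idx = kwargs.get("node_idx")
    return {"sparsity": sparsity, "node_idx": node_idx}


def explain_strict(x: Any, edge_index: Any, **kwargs: Any) -> Dict[str, Any]:
    """Same as explain(), but rejects keyword arguments it does not recognize."""
    unknown = set(kwargs) - KNOWN_OPTIONS
    if unknown:
        raise TypeError(f"unexpected keyword argument(s): {sorted(unknown)}")
    return explain(x, edge_index, **kwargs)


if __name__ == "__main__":
    print(explain(None, None, node_idx=13, sparcity=0.5))  # sparsity stays 0.7
    print(explain(None, None, node_idx=13, sparsity=0.5))  # sparsity is 0.5
```

The strict variant trades flexibility for early failure: a typo surfaces as a TypeError at the call site instead of a silently different result.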

Oliver-0423 commented 1 year ago

@CM-BF I still have this problem when I run it on Google Colab. When I run it on my Windows laptop, it shows a different error: KeyError: 'explain_message'. My DIG version is 1.0.0 and my PyG version is 2.1.0. Thank you.

CM-BF commented 1 year ago


Can you please post the code you ran and the corresponding error reports?

Oliver-0423 commented 1 year ago

@CM-BF

KeyError                                  Traceback (most recent call last)
      1 node_idx=torch.tensor(10)
      2 num_classes=train_dataset.num_classes
----> 3 edge_masks, hard_edge_masks, related_preds = explainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes, node_idx=node_idx)
      4

8 frames
/usr/local/lib/python3.7/dist-packages/torch_geometric/nn/conv/utils/inspector.py in distribute(self, func_name, kwargs)
     52     def distribute(self, func_name, kwargs: Dict[str, Any]):
     53         out = {}
---> 54         for key, param in self.params[func_name].items():
     55             data = kwargs.get(key, inspect.Parameter.empty)
     56             if data is inspect.Parameter.empty:

KeyError: 'explain_message'
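The last frame shows where the KeyError comes from: distribute() indexes self.params[func_name], so it fails whenever no signature was ever recorded under the name 'explain_message'. That is consistent with (but does not prove) a mismatch between the installed DIG and PyG versions over which message-passing methods get inspected. A simplified sketch of the lookup pattern, modeled only on the lines visible in the traceback; MiniInspector, its register() method, and the message() function are illustrative and are not PyG's actual code:

```python
import inspect
from typing import Any, Callable, Dict


class MiniInspector:
    """Simplified stand-in for the lookup shown in the traceback (not PyG code)."""

    def __init__(self) -> None:
        # Maps a function name to its parameters, recorded by register().
        self.params: Dict[str, Dict[str, inspect.Parameter]] = {}

    def register(self, func: Callable[..., Any]) -> None:
        """Record the parameters of func under func.__name__."""
        self.params[func.__name__] = dict(inspect.signature(func).parameters)

    def distribute(self, func_name: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Pick the arguments that func_name needs out of kwargs."""
        out = {}
        # Raises KeyError if func_name was never passed to register().
        for key, param in self.params[func_name].items():
            data = kwargs.get(key, inspect.Parameter.empty)
            if data is inspect.Parameter.empty:
                if param.default is inspect.Parameter.empty:
                    raise TypeError(f"required parameter {key!r} is missing")
                data = param.default
            out[key] = data
        return out


def message(x_j, edge_weight=None):
    """Illustrative message function; only its signature matters here."""
    return x_j


if __name__ == "__main__":
    insp = MiniInspector()
    insp.register(message)
    print(insp.distribute("message", {"x_j": [1.0], "unused": 0}))
    try:
        insp.distribute("explain_message", {"x_j": [1.0]})
    except KeyError as err:
        print("KeyError:", err)  # KeyError: 'explain_message'
```

Only "message" was registered, so asking for "explain_message" reproduces the same KeyError as the traceback.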