MyRespect closed this issue 2 years ago.
Hello, thank you for opening this issue.
In DIG xgraph, a GNN's `forward` function must accept a `torch_geometric.data.Data` object as valid input, passed as the `data` keyword argument.
In our implementation, the models have an `arguments_read` function to handle this.
Feel free to open another issue if you run into further problems.
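The idea behind `arguments_read` can be sketched as a helper that normalises the possible calling conventions (`model(x, edge_index)`, `model(x=..., edge_index=...)`, or `model(data=batch)`) into `(x, edge_index, batch)` before any layer runs. This is a minimal sketch, not DIG's code: `read_graph_args` and `MeanAggGNN` are hypothetical names, the toy model uses plain PyTorch so it runs without torch_geometric, and a `types.SimpleNamespace` stands in for a `torch_geometric.data.Data` object.

```python
from types import SimpleNamespace

import torch
from torch import nn


def read_graph_args(*args, **kwargs):
    """Hypothetical helper in the spirit of DIG's `arguments_read` (not its code).

    Accepts model(x, edge_index[, batch]), model(x=..., edge_index=...[, batch=...])
    or model(data=...), and returns (x, edge_index, batch).
    """
    data = kwargs.get("data")
    if data is not None:
        # Called as model(data=batch), which is how the explainer's value_func calls it.
        x, edge_index = data.x, data.edge_index
        batch = getattr(data, "batch", None)
    elif len(args) >= 2:
        # Called positionally: model(x, edge_index[, batch]).
        x, edge_index = args[0], args[1]
        batch = args[2] if len(args) > 2 else kwargs.get("batch")
    elif "x" in kwargs and "edge_index" in kwargs:
        # Called with keywords: model(x=..., edge_index=...[, batch=...]).
        x, edge_index = kwargs["x"], kwargs["edge_index"]
        batch = kwargs.get("batch")
    else:
        raise ValueError("expected (x, edge_index[, batch]) or data=...")
    if batch is None:
        # No batch vector means a single graph: every node belongs to graph 0.
        batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
    return x, edge_index, batch


class MeanAggGNN(nn.Module):
    """Hypothetical toy graph classifier: one mean-aggregation step,
    mean pooling per graph, then a linear layer producing logits."""

    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        self.lin = nn.Linear(in_dim, num_classes)

    def forward(self, *args, **kwargs):
        x, edge_index, batch = read_graph_args(*args, **kwargs)
        src, dst = edge_index
        # Average the features of each node's incoming neighbours.
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])
        deg = torch.zeros(x.size(0), device=x.device).index_add_(
            0, dst, torch.ones(dst.numel(), device=x.device))
        h = x + agg / deg.clamp(min=1).unsqueeze(-1)
        # Mean-pool node embeddings per graph id in `batch`.
        num_graphs = int(batch.max()) + 1
        pooled = torch.zeros(num_graphs, h.size(1), device=h.device).index_add_(0, batch, h)
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
        return self.lin(pooled / counts.unsqueeze(-1).to(h.dtype))


# Usage: a 4-node cycle; SimpleNamespace stands in for torch_geometric.data.Data.
torch.manual_seed(0)
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
model = MeanAggGNN(in_dim=3, num_classes=2)
out_pos = model(x, edge_index)
out_kw = model(x=x, edge_index=edge_index)
out_data = model(data=SimpleNamespace(x=x, edge_index=edge_index))
```

Because all three conventions resolve to the same tensors, the same model can be called directly by user code and by an explainer that passes `data=batch`.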
Update: Thank you so much to the developers for the quick response. Following @Oceanusity's advice, I added the `arguments_read` function to our model, and it works.
I have sent you an email, since I am not sure whether this is caused by our self-built dataset.
I have successfully installed the DIG library using virtualenv (PyTorch 1.6.0, torch_geometric 1.7.0), but when I run the code, it raises the following error:
```
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_9260/3395182850.py in <module>
     25 prediction = model(data.x, data.edge_index).argmax(1)
     26 _, explanation_results, related_preds = \
---> 27     explainer(data.x, data.edge_index, max_nodes=max_nodes)
     28 
     29 explanation_results = explanation_results[prediction]

~/dig-mengliu/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in __call__(self, x, edge_index, **kwargs)
    849 
    850         for label_idx, label in enumerate(ex_labels):
--> 851             results, related_pred = self.explain(x, edge_index,
    852                                                  label=label,
    853                                                  max_nodes=max_nodes,

.....

~/dig-mengliu/lib/python3.8/site-packages/dig/xgraph/method/shapley.py in value_func(batch)
     12     def value_func(batch):
     13         with torch.no_grad():
---> 14             logits = gnnNets(data=batch)
     15             probs = F.softmax(logits, dim=-1)
     16             score = probs[:, target_class]

~/dig-mengliu/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

TypeError: forward() got an unexpected keyword argument 'data'
```
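The last frames show the mechanism: `value_func` in `shapley.py` calls `gnnNets(data=batch)`, and `nn.Module`'s `_call_impl` passes that keyword unchanged to `self.forward`, so a `forward(self, x, edge_index)` signature fails. One way to catch this before running the explainer is to inspect `model.forward`. This is a minimal sketch using only the standard library's `inspect` module; `accepts_data_kwarg`, `PositionalOnlyModel` and `DigStyleModel` are hypothetical names, and the two plain classes stand in for real `nn.Module` models.

```python
import inspect


def accepts_data_kwarg(forward) -> bool:
    """Hypothetical pre-flight check: can `forward` be called as forward(data=...)?

    Inspect `model.forward`, not `model`: nn.Module's call wrapper takes
    (*input, **kwargs), so its own signature says nothing about forward's.
    """
    for p in inspect.signature(forward).parameters.values():
        if p.kind is inspect.Parameter.VAR_KEYWORD:
            return True  # **kwargs absorbs data=...
        if p.name == "data" and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


# Plain stand-ins for two model styles, so no torch is needed here.
class PositionalOnlyModel:
    def forward(self, x, edge_index):  # a signature like this raises the TypeError
        ...


class DigStyleModel:
    def forward(self, *args, **kwargs):  # accepts data=... through **kwargs
        ...


print(accepts_data_kwarg(PositionalOnlyModel().forward))  # False
print(accepts_data_kwarg(DigStyleModel().forward))        # True
```

With a real model, pass the bound method `model.forward`; a `False` result means the model needs `data=` handling (for example an `arguments_read`-style function) before it is handed to the explainer.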