divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

Compatible issue with subgraphx and torch module #75

Closed MyRespect closed 2 years ago

MyRespect commented 2 years ago

Update: Thank you to the developers for the prompt response. Following @Oceanusity's advice, I added the "arguments_read" function to our model, and it works.

I have also sent you an email, since I am not sure whether this is caused by our self-built dataset.

I have successfully installed the DIG library in a virtualenv with PyTorch 1.6.0 and torch_geometric 1.7.0. But when I run the code, it raises the following error:

TypeError Traceback (most recent call last)
/tmp/ipykernel_9260/3395182850.py in <module>
  25 prediction = model(data.x, data.edge_index).argmax(1)
  26 _, explanation_results, related_preds = \
---> 27 explainer(data.x, data.edge_index, max_nodes=max_nodes)
  28
  29 explanation_results = explanation_results[prediction]

~/dig-mengliu/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in __call__(self, x, edge_index, **kwargs)
  849
  850 for label_idx, label in enumerate(ex_labels):
--> 851 results, related_pred = self.explain(x, edge_index,
  852 label=label,
  853 max_nodes=max_nodes,

.....

~/dig-mengliu/lib/python3.8/site-packages/dig/xgraph/method/shapley.py in value_func(batch)
  12 def value_func(batch):
  13 with torch.no_grad():
---> 14 logits = gnnNets(data=batch)
  15 probs = F.softmax(logits, dim=-1)
  16 score = probs[:, target_class]

~/dig-mengliu/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
  720 result = self._slow_forward(*input, **kwargs)
  721 else:
--> 722 result = self.forward(*input, **kwargs)
  723 for hook in itertools.chain(
  724 _global_forward_hooks.values(),

TypeError: forward() got an unexpected keyword argument 'data'
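The last two frames show the model being called as gnnNets(data=batch), and that keyword being passed on to forward(). The following is a minimal sketch of how that produces the error, not the reporter's actual model: PositionalOnlyModel is a hypothetical stand-in whose forward() only accepts x and edge_index, and its __call__ mimics the forwarding in the last frame without using torch.

```python
class PositionalOnlyModel:
    """Hypothetical stand-in for a GNN whose forward() takes (x, edge_index) only."""

    def forward(self, x, edge_index):
        return x

    def __call__(self, *args, **kwargs):
        # Mimics the forwarding step shown in the last traceback frame.
        return self.forward(*args, **kwargs)


model = PositionalOnlyModel()

try:
    # Same call style as value_func in the traceback: a single `data` keyword.
    model(data=object())
    message = None
except TypeError as err:
    message = str(err)

print(message)
```

The call fails before any graph data is touched, so a signature mismatch like this would occur regardless of the dataset used.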

Oceanusity commented 2 years ago

Hello, thank you for your issue.

In DIG xgraph, the forward function of the GNN should accept a torch_geometric.data.Data object as valid input, because the explainer calls the model as gnnNets(data=batch).

In our implementation, the model has an arguments_read function to handle this.
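As a rough illustration of the idea, here is a sketch of what such an argument-normalizing helper might look like. It is an assumption written for this thread, not DIG's actual arguments_read: the function body, ToyGNN, and the SimpleNamespace stand-in for torch_geometric.data.Data are all hypothetical, and plain lists replace tensors so the sketch runs without torch.

```python
from types import SimpleNamespace


def arguments_read(*args, **kwargs):
    """Return (x, edge_index, batch) for several call styles.

    Hypothetical helper for illustration; not DIG's implementation.
    """
    data = kwargs.get("data")
    if data is not None:
        # Keyword style seen in the traceback: model(data=batch)
        return data.x, data.edge_index, getattr(data, "batch", None)
    if len(args) in (2, 3):
        # Positional style: model(x, edge_index) or model(x, edge_index, batch)
        batch = args[2] if len(args) == 3 else kwargs.get("batch")
        return args[0], args[1], batch
    if not args and "x" in kwargs and "edge_index" in kwargs:
        # Keyword style: model(x=..., edge_index=...)
        return kwargs["x"], kwargs["edge_index"], kwargs.get("batch")
    raise ValueError("expected data=..., or x and edge_index (optionally batch)")


class ToyGNN:
    """Stand-in for a torch.nn.Module subclass (assumption: no torch needed here)."""

    def forward(self, *args, **kwargs):
        x, edge_index, batch = arguments_read(*args, **kwargs)
        # A real model would run message passing here; the sketch reports sizes.
        return {"num_nodes": len(x), "num_edges": len(edge_index[0]), "batch": batch}

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


# Data-like object with three nodes and two edges (0 -> 1, 1 -> 2).
graph = SimpleNamespace(x=[[0.1], [0.2], [0.3]], edge_index=[[0, 1], [1, 2]], batch=None)

model = ToyGNN()
out_keyword = model(data=graph)                     # call style used by the explainer
out_positional = model(graph.x, graph.edge_index)   # call style used in the notebook
```

With a helper like this at the top of forward(), both the notebook's positional call and the explainer's data= call reach the same code path. In a real model, a missing batch would typically be replaced with a default batch vector before pooling.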

Feel free to open another issue if you run into further problems.