divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

Issue in GNNExplainer tutorial. #50

Closed. het-25 closed this issue 3 years ago.

het-25 commented 3 years ago

I tried to run the tutorial for GNNExplainer. However, I consistently get the following ValueError. It would be a great help if you could help me out here:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input> in <module>
      8     if torch.isnan(data.y[0].squeeze()):
      9         continue
---> 10     edge_masks, hard_edge_masks, related_preds = explainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes, node_idx=node_idx)
     11
     12     x_collector.collect_data(hard_edge_masks, related_preds, data.y[0].squeeze().long().item())

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/.local/lib/python3.8/site-packages/dig/xgraph/method/gnnexplainer.py in forward(self, x, edge_index, mask_features, **kwargs)
    143             self.__clear_masks__()
    144             self.__set_masks__(x, self_loop_edge_index)
--> 145             edge_masks.append(self.control_sparsity(self.gnn_explainer_alg(x, edge_index, ex_label), sparsity=kwargs.get('sparsity')))
    146             # edge_masks.append(self.gnn_explainer_alg(x, edge_index, ex_label))
    147

~/.local/lib/python3.8/site-packages/dig/xgraph/method/gnnexplainer.py in gnn_explainer_alg(self, x, edge_index, ex_label, mask_features, **kwargs)
     84             h = x
     85             raw_preds = self.model(x=h, edge_index=edge_index, **kwargs)
---> 86             loss = self.__loss__(raw_preds, ex_label)
     87             if epoch % 20 == 0 and debug:
     88                 print(f'Loss:{loss.item()}')

~/.local/lib/python3.8/site-packages/dig/xgraph/method/gnnexplainer.py in __loss__(self, raw_preds, x_label)
     44     def __loss__(self, raw_preds: Tensor, x_label: Union[Tensor, int]):
     45         if self.explain_graph:
---> 46             loss = cross_entropy_with_logit(raw_preds, x_label)
     47         else:
     48             loss = cross_entropy_with_logit(raw_preds[self.node_idx].unsqueeze(0), x_label)

~/.local/lib/python3.8/site-packages/dig/xgraph/method/gnnexplainer.py in cross_entropy_with_logit(y_pred, y_true, **kwargs)
     10
     11 def cross_entropy_with_logit(y_pred: torch.Tensor, y_true: torch.Tensor, **kwargs):
---> 12     return cross_entropy(y_pred, y_true.long(), **kwargs)
     13
     14 class GNNExplainer(ExplainerBase):

~/.local/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2822     if size_average is not None or reduce is not None:
   2823         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2825
   2826

ValueError: Expected input batch_size (700) to match target batch_size (1).
```
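One reading of the traceback: `__loss__` takes the `self.explain_graph` branch, so the full per-node logit matrix (700 nodes, per the error message) is handed to `cross_entropy` against a single label of shape `(1,)`, which produces exactly this batch-size mismatch. A minimal sketch reproducing the mismatch with plain PyTorch follows; the 4 classes are an assumed dimension, not taken from the tutorial.

```python
import torch
import torch.nn.functional as F

# Assumed shapes: 700 nodes (matches the error message) and 4 classes (arbitrary).
raw_preds = torch.randn(700, 4)   # per-node logits from a node-classification GNN
ex_label = torch.tensor([1])      # a single label of shape (1,)

try:
    # Graph-level branch: compares (700, 4) logits with a (1,) target,
    # which is the mismatch reported in the traceback.
    F.cross_entropy(raw_preds, ex_label)
except ValueError as e:
    print(e)  # Expected input batch_size (700) to match target batch_size (1).

# Node-level branch: select the logits of the node being explained first,
# so input and target batch sizes both equal 1 and the loss computes fine.
node_idx = 0
loss = F.cross_entropy(raw_preds[node_idx].unsqueeze(0), ex_label)
print(loss.item())
```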
Oceanusity commented 3 years ago

Hello, we have updated the package. It should now be fine to run the examples with the latest version of dig.

Thank you for raising this issue, and feel free to open another one if you run into further problems.
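For anyone landing here later, a quick sketch for checking the installed release before re-running the tutorial. It assumes DIG was installed from PyPI under the distribution name `dive_into_graphs` (the import name is `dig`); verify the name and upgrade command against the DIG installation docs.

```python
# Check which DIG release is installed before re-running the GNNExplainer tutorial.
# Assumption: the PyPI distribution name is "dive_into_graphs"; adjust if your
# environment installed it under a different name.
from importlib.metadata import version

print(version("dive_into_graphs"))

# Then upgrade if needed, e.g.:
#   pip install --upgrade dive_into_graphs
```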