divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

AttributeError: 'GCNConv' object has no attribute '__check_input__' #206

Open. FabioDataGeek opened this issue 1 year ago.

FabioDataGeek commented 1 year ago

I'm trying to run the examples for explainable GNNs (xgnn), but when running the model imported from dig.xgraph.models I get the error above.

I'm currently running the experiment in a conda environment with:

PyTorch 2.0.0, Python 3.9, CUDA 11.7, torch_geometric 2.3.0

Here is the traceback:

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[8], line 10
      7 if torch.isnan(data.y[0].squeeze()):
      8     continue
---> 10 logits = model(data.x, data.edge_index)
     11 prediction = logits[node_idx].argmax(-1).item()
     13 _, explanation_results, related_preds = explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)

File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/dig/xgraph/models/models.py:164, in GCN_2l.forward(self, *args, **kwargs)
    158 """
    159 :param Required[data]: Batch - input data
    160 :return:
    161 """
    162 x, edge_index, batch = self.arguments_read(*args, **kwargs)
--> 164 post_conv = self.relu1(self.conv1(x, edge_index))
    165 for conv, relu in zip(self.convs, self.relus):
    166     post_conv = relu(conv(post_conv, edge_index))

File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/dig/xgraph/models/models.py:350, in GCNConv.forward(self, x, edge_index, edge_weight)
    347 x = torch.matmul(x, self.weight)
    349 # propagate_type: (x: Tensor, edge_weight: OptTensor)
--> 350 out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
    351                      size=None)
    353 if self.bias is not None:
    354     out += self.bias

File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/dig/xgraph/models/models.py:362, in GCNConv.propagate(self, edge_index, size, **kwargs)
    361 def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
--> 362     size = self.__check_input__(edge_index, size)
    364     # Run "fused" message and aggregation (if applicable).
    365     if (isinstance(edge_index, SparseTensor) and self.fuse
    366             and not self._explain):

File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/torch/nn/modules/module.py:1614, in Module.__getattr__(self, name)
   1612     if name in modules:
   1613         return modules[name]
-> 1614 raise AttributeError("'{}' object has no attribute '{}'".format(
   1615     type(self).__name__, name))

AttributeError: 'GCNConv' object has no attribute '__check_input__'
```

I also tried GIN_2l and got the same error as with GCN_2l.

The part of the code that is failing is this:

```python
import torch

# dataset, model, explainer and device are defined in earlier cells of the notebook.

# --- Create data collector and explanation processor ---
from dig.xgraph.evaluation import XCollector

x_collector = XCollector()

index = -1
# Test nodes whose label is non-zero (mask * label, then compare with 0).
node_indices = torch.where(dataset[0].test_mask * dataset[0].y != 0)[0].tolist()
data = dataset[0]

from dig.xgraph.method.subgraphx import PlotUtils
from dig.xgraph.method.subgraphx import find_closest_node_result

# Visualization
max_nodes = 5
node_idx = node_indices[20]
print(f'explain graph node {node_idx}')
data.to(device)
logits = model(data.x, data.edge_index)  # <- raises the AttributeError above
prediction = logits[node_idx].argmax(-1).item()

_, explanation_results, related_preds = explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)

explanation_results = explanation_results[prediction]
explanation_results = explainer.read_from_MCTSInfo_list(explanation_results)

plotutils = PlotUtils(dataset_name='ba_shapes', is_show=True)
explainer.visualization(explanation_results, max_nodes=max_nodes, plot_utils=plotutils, y=data.y)
```

The code is taken from the DIG repository: https://github.com/divelab/DIG/blob/dig-stable/examples/xgraph/subgraphx.ipynb

LongchaoDa commented 1 year ago

I have run into the same problem. Have you solved it?

FabioDataGeek commented 1 year ago

No. For now I'm running explainability methods in PyTorch Geometric directly. I would like to test this library, but I cannot run the examples.

Krith-man commented 1 year ago

I have run into the same problem. Any updates?

Krith-man commented 1 year ago

Hello again, is there any news on this issue?

Qianli-Wu commented 11 months ago

I experienced the same error: AttributeError: 'GCNConv' object has no attribute '__check_input__'.

This is due to a naming convention change in torch_geometric v2.3.0, where internal attributes named __attribute__ were renamed to _attribute. For example, __check_input__ and __user_args__ became _check_input and _user_args, respectively. More details can be found in this pull request: https://github.com/pyg-team/pytorch_geometric/pull/6999
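Because the rename only changed the surrounding underscores, code that must run against both older and newer torch_geometric releases can look the attribute up under either name. The sketch below assumes, as described above, that the new name is `_<base>` and the old one is `__<base>__`; `resolve_private` is a hypothetical helper, not part of DIG or torch_geometric, and the two small classes only stand in for a layer before and after the rename.

```python
from typing import Any

_MISSING = object()  # sentinel so a legitimately falsy attribute is still returned


def resolve_private(obj: Any, base: str) -> Any:
    """Return obj._<base> (torch_geometric >= 2.3.0 naming) if it exists,
    otherwise obj.__<base>__ (older naming). Hypothetical helper."""
    for name in (f"_{base}", f"__{base}__"):
        # getattr with a default also absorbs the AttributeError that
        # torch.nn.Module.__getattr__ raises for unknown names.
        value = getattr(obj, name, _MISSING)
        if value is not _MISSING:
            return value
    raise AttributeError(
        f"{type(obj).__name__!r} object has neither '_{base}' nor '__{base}__'"
    )


# Stand-ins for a message-passing layer before and after the rename.
class OldStyle:
    def __check_input__(self, edge_index, size):
        return "old"


class NewStyle:
    def _check_input(self, edge_index, size):
        return "new"


# In a copied propagate() one could write, for example:
#   size = resolve_private(self, "check_input")(edge_index, size)
print(resolve_private(OldStyle(), "check_input")(None, None))  # prints: old
print(resolve_private(NewStyle(), "check_input")(None, None))  # prints: new
```

Preferring the new name first means the lookup keeps working once only the new naming remains.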

To fix it:

Update the attribute names in the installed DIG code. For instance, in ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/dig/xgraph/models/models.py:362, change __check_input__ to _check_input. Do the same for the other renamed attributes.
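The manual edit can also be scripted. This is a minimal sketch, assuming the installed models.py can be rewritten in place; only `__check_input__` and `__user_args__` are named in this thread, so the `RENAMES` mapping is an assumption to extend if the next run reports another missing attribute. `patch_source` and `patch_file` are hypothetical helpers, and a `.bak` copy of the original file is kept.

```python
import shutil
from pathlib import Path

# Old (< 2.3.0) -> new (>= 2.3.0) attribute names. Only these two are named
# in this thread; add further pairs if another attribute turns up missing.
RENAMES = {
    "__check_input__": "_check_input",
    "__user_args__": "_user_args",
}


def patch_source(text: str, renames: dict = RENAMES) -> str:
    """Return `text` with every old attribute name replaced by its new name."""
    for old, new in renames.items():
        text = text.replace(old, new)
    return text


def patch_file(path, renames: dict = RENAMES) -> bool:
    """Rewrite the file in place, keeping a '.bak' copy of the original.

    Returns True if the file changed, False if there was nothing to rename.
    """
    path = Path(path).expanduser()
    original = path.read_text()
    patched = patch_source(original, renames)
    if patched == original:
        return False
    shutil.copy2(path, path.with_name(path.name + ".bak"))
    path.write_text(patched)
    return True
```

For example, `patch_file("~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/dig/xgraph/models/models.py")`, then restart the notebook kernel so the module is re-imported. The patched file then only works with the new naming, so it would need the `.bak` copy restored after a downgrade.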

Alternatively, downgrade torch_geometric to v2.2.0 or earlier with pip install torch_geometric==2.2.0 to avoid the naming issue. Note that this may affect other parts of your code.
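Before editing files or downgrading, it can help to confirm which torch_geometric release the active environment actually has. A minimal sketch using only the standard library's `importlib.metadata`; the 2.3.0 cut-over comes from the comment above, and `parse_release`, `uses_new_private_names` and `installed_pyg_version` are hypothetical helper names.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def parse_release(ver: str) -> Tuple[int, ...]:
    """Leading numeric release components, e.g. '2.3.0+cu117' -> (2, 3, 0)."""
    parts = []
    for piece in ver.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if digits != piece:  # stop at suffixes such as '0+cu117' or '0rc1'
            break
    return tuple(parts)


def uses_new_private_names(ver: str) -> bool:
    """True for 2.3.0 and later, which (per this thread) use _name, not __name__."""
    release = parse_release(ver)
    release += (0,) * max(0, 3 - len(release))  # treat '2.3' as '2.3.0'
    return release >= (2, 3, 0)


def installed_pyg_version() -> Optional[str]:
    """Installed torch_geometric version string, or None if it is not installed."""
    try:
        return version("torch_geometric")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    ver = installed_pyg_version()
    if ver is None:
        print("torch_geometric is not installed")
    elif uses_new_private_names(ver):
        print(f"torch_geometric {ver}: renamed attributes (_check_input, ...)")
    else:
        print(f"torch_geometric {ver}: old attributes (__check_input__, ...)")
```

Run it with the same interpreter as the notebook (for example, from a notebook cell), since a conda environment and the system Python can have different versions installed.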

ZhongLIFR commented 4 months ago

I just had the same problem. It was solved by downgrading torch_geometric to v2.2.0 with pip install torch_geometric==2.2.0.