pyg-team / pytorch_geometric


HeteroConv #9312

Open ZhenjiangFan opened 1 month ago

ZhenjiangFan commented 1 month ago

šŸ› Describe the bug

# HeteroConv and GCNConv come from torch_geometric.nn; one conv is built per edge type.
tempHeteroDict = {}
for key in sageLayerNameList:
    tempHeteroDict[key] = GCNConv(sage_dim_in, sage_dim_in)
self.hetero_conv = HeteroConv(tempHeteroDict, aggr='lstm')

Hi, thank you for your amazing work.

I encountered a bug while using HeteroConv with the aggregation scheme set to "lstm"; the full error message is at the bottom of this report. The other aggregation schemes work fine. The output of "collect_env.py" is included below under Versions. It would be great if you have any suggestions.

Thank you.
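
In case it helps with reproducing: the snippet above comes from a larger model, so here is a minimal standalone sketch that should trigger the same error. The node/edge type names, feature size, and the two parallel edge types are made up for illustration; any setup where one destination node type receives outputs from at least two convolutions should reach the same aggregation code path.

import torch
from torch_geometric.nn import GCNConv, HeteroConv

# Toy heterogeneous input: one node type and two edge types that share the
# same destination type, so HeteroConv must aggregate two outputs per node type.
x_dict = {'paper': torch.randn(8, 16)}
edge_index_dict = {
    ('paper', 'cites', 'paper'): torch.randint(0, 8, (2, 20)),
    ('paper', 'refers', 'paper'): torch.randint(0, 8, (2, 20)),
}

conv = HeteroConv(
    {edge_type: GCNConv(16, 16) for edge_type in edge_index_dict},
    aggr='lstm',  # 'sum', 'mean', 'max', 'min', 'cat' all run without error here
)

out_dict = conv(x_dict, edge_index_dict)  # raises the TypeError shown below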

Versions

Collecting environment information...
PyTorch version: 2.1.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.1.0.2.5)
CMake version: Could not collect
Libc version: N/A

Python version: 3.11.5 | packaged by conda-forge | (main, Aug 27 2023, 03:33:12) [Clang 15.0.7 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M3 Pro

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] numpydoc==1.5.0
[pip3] torch==2.1.2
[pip3] torch-cluster==1.6.3
[pip3] torch-geometric==2.6.0
[pip3] torch-scatter==2.1.2
[pip3] torch-sparse==0.6.18
[pip3] torch-spline-conv==1.2.2
[pip3] torchdata==0.7.1
[conda] numpy 1.24.3 py311hb57d4eb_0
[conda] numpy-base 1.24.3 py311h1d85a46_0
[conda] numpydoc 1.5.0 py311hca03da5_0
[conda] torch 2.1.2 pypi_0 pypi
[conda] torch-cluster 1.6.3 pypi_0 pypi
[conda] torch-geometric 2.6.0 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torch-sparse 0.6.18 pypi_0 pypi
[conda] torch-spline-conv 1.2.2 pypi_0 pypi
[conda] torchdata 0.7.1 pypi_0 pypi


TypeError                                 Traceback (most recent call last)
File :32

Cell In[8], line 82, in CustomGNN.fit(self, graph_data_x, x_dict, edge_index_dict, homo_graph_data, epochs)
     77 for epoch in range(epochs+1):
     78
     79     # Train
     80     optimizer.zero_grad()
---> 82     _, out = self(x_dict, edge_index_dict, homo_graph_data);
     84     loss = criterion(out[graph_data_x.train_mask], graph_data_x.y[graph_data_x.train_mask])
     85     acc = accuracy(out[graph_data_x.train_mask].argmax(dim=1), graph_data_x.y[graph_data_x.train_mask])

File ~/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[8], line 56, in CustomGNN.forward(self, x_dict, edge_index_dict, homo_graph_data)
     49 """
     50 Sage's foward function parameters
     51 The first parameter is the data matrix, with rows being nodes and colomns being features
     52 The Second parameter is edge index with a dimension of 2xNumberOfEdges, each row is an edge
     53 """
     54 #======================================
---> 56 out_dict = self.hetero_conv(x_dict, edge_index_dict);
     57 hSum = out_dict[list(out_dict.keys())[0]];
     58 #======================================

File ~/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/anaconda3/lib/python3.11/site-packages/torch_geometric/nn/conv/hetero_conv.py:166, in HeteroConv.forward(self, *args_dict, **kwargs_dict)
    163     out_dict[dst].append(out)
    165 for key, value in out_dict.items():
--> 166     out_dict[key] = group(value, self.aggr)
    168 return out_dict

File ~/anaconda3/lib/python3.11/site-packages/torch_geometric/nn/conv/hetero_conv.py:24, in group(xs, aggr)
     22 else:
     23     out = torch.stack(xs, dim=0)
---> 24     out = getattr(torch, aggr)(out, dim=0)
     25     out = out[0] if isinstance(out, tuple) else out
     26     return out

TypeError: lstm() received an invalid combination of arguments - got (Tensor, dim=int), but expected one of:
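
If it helps with triage: judging from the last two frames, group() in hetero_conv.py stacks the per-edge-type outputs and then calls getattr(torch, aggr)(out, dim=0). For aggr='lstm' that resolves to torch.lstm, which is the functional RNN entry point rather than a reduction, so it rejects the (Tensor, dim=int) call. The following is only a small sketch with synthetic tensors (not the library code itself) showing the difference:

import torch

# Two per-edge-type outputs for the same destination node type, stacked as in group().
xs = [torch.randn(8, 16), torch.randn(8, 16)]
out = torch.stack(xs, dim=0)

# Reductions that exist as torch.<name>(tensor, dim=...) work:
torch.sum(out, dim=0)
torch.mean(out, dim=0)
torch.max(out, dim=0)[0]   # returns (values, indices); group() keeps element 0

# torch.lstm is the functional LSTM op, not a reduction, so this call fails:
getattr(torch, 'lstm')(out, dim=0)   # TypeError: lstm() received an invalid combination of arguments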