pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

HeteroConv fails when only GCNConv layers are used #8034

Closed: ver228 closed this issue 1 year ago

ver228 commented 1 year ago

🐛 Describe the bug

It seems that if only GCNConv layers are used inside a HeteroConv, the forward pass fails.

Below is a minimal example:

import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv

dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected())
data = dataset[0]

hidden_channels = 32

conv_het = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels, add_self_loops=False),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
            ('paper', 'rev_writes', 'author'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
        }, aggr='sum')

conv_only_gcn = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels, add_self_loops=False),
            ('author', 'writes', 'paper'): GCNConv(-1, hidden_channels, add_self_loops=False),
            ('paper', 'rev_writes', 'author'): GCNConv(-1, hidden_channels, add_self_loops=False),
        }, aggr='sum')

# Success!!! 
out1 = conv_het(data.x_dict, data.edge_index_dict)

# Fails :(
out2 = conv_only_gcn(data.x_dict, data.edge_index_dict)

The error message I get is:

  File "/opt/pyenv/versions/3.11.4/lib/python3.11/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 211, in forward
    edge_index, edge_weight, x.size(self.node_dim),
                             ^^^^^^
AttributeError: 'tuple' object has no attribute 'size'

Thanks for your help!

Environment

EdisonLeeeee commented 1 year ago

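That's because GCNConv does not support heterogeneous graphs, or more precisely bipartite message passing. For edge types whose source and destination node types differ (e.g. ('author', 'writes', 'paper')), HeteroConv passes a (source, destination) tuple of node features, and GCNConv expects a single feature tensor, hence the `'tuple' object has no attribute 'size'` error. Check here for all supported operators (marked as bipartite).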
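The failure comes down to how HeteroConv chooses the node-feature argument for each edge type. Below is a simplified, pure-Python sketch of that selection logic, assuming it mirrors PyG's behaviour; `select_conv_input` is a hypothetical name, not part of the PyG API, and strings stand in for feature tensors.

```python
from typing import Any, Dict, Tuple, Union

EdgeType = Tuple[str, str, str]


def select_conv_input(
    x_dict: Dict[str, Any], edge_type: EdgeType
) -> Union[Any, Tuple[Any, Any]]:
    """Hypothetical, simplified sketch of how HeteroConv picks the feature
    argument for one edge type (assumed behaviour, not PyG's actual code)."""
    src, _, dst = edge_type
    if src == dst:
        # Same node type on both ends: a single feature matrix is passed,
        # which GCNConv handles fine.
        return x_dict[src]
    # Different node types: a (source, destination) pair is passed. Only
    # bipartite-capable layers such as SAGEConv or GATConv accept this;
    # GCNConv calls x.size(...) on the tuple and raises AttributeError.
    return (x_dict.get(src), x_dict.get(dst))


# Placeholder strings stand in for the real feature tensors.
x_dict = {"paper": "x_paper", "author": "x_author"}
print(select_conv_input(x_dict, ("paper", "cites", "paper")))
print(select_conv_input(x_dict, ("author", "writes", "paper")))
```

This is why `conv_het` works: its only GCNConv sits on ('paper', 'cites', 'paper'), where a single tensor is passed, while the mixed-type edges use SAGEConv and GATConv.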

ver228 commented 1 year ago

Thanks a lot! Now this makes sense :)

Do you know if there is a way to access this information programmatically? I am trying to make the base GNN layer configurable (e.g. swapping in GCNConv), and it would be nice to throw a more informative error if somebody makes the same mistake I was making.
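One way to fail early is to validate the layer choice against the edge types before building the HeteroConv. The sketch below is a hypothetical helper, not a PyG API: `check_layer_choice` and the `BIPARTITE_LAYERS` allowlist are assumptions, and the allowlist only contains SAGEConv and GATConv because those are the layers that accepted tuple inputs in the example above.

```python
from typing import Iterable, Tuple

EdgeType = Tuple[str, str, str]

# Hypothetical allowlist: layers known (from the example above) to accept
# (x_src, x_dst) tuples. Extend it for the layers your project uses.
BIPARTITE_LAYERS = {"SAGEConv", "GATConv"}


def check_layer_choice(layer_name: str, edge_types: Iterable[EdgeType]) -> None:
    """Raise a descriptive error if layer_name would be used on edge types
    that connect different node types but is not known to be bipartite."""
    bipartite = [et for et in edge_types if et[0] != et[2]]
    if bipartite and layer_name not in BIPARTITE_LAYERS:
        raise ValueError(
            f"{layer_name} does not support bipartite message passing, but "
            f"these edge types connect different node types: {bipartite}. "
            f"Use one of {sorted(BIPARTITE_LAYERS)} for them instead."
        )


edge_types = [
    ("paper", "cites", "paper"),
    ("author", "writes", "paper"),
    ("paper", "rev_writes", "author"),
]
check_layer_choice("SAGEConv", edge_types)  # no error: SAGEConv is allowlisted
try:
    check_layer_choice("GCNConv", edge_types)
except ValueError as err:
    print(err)  # explains which edge types need a bipartite layer
```

A name-based allowlist is easy to read but has to be maintained by hand; an alternative would be to probe each candidate layer once on a tiny bipartite input and catch the exception.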

EdisonLeeeee commented 1 year ago

Actually, a more informative error message was added to PyG in #7637, so you might need to upgrade your PyG version :)

ver228 commented 1 year ago

Thanks! I see the change hasn't made it into a release yet: the latest release is from 27 Apr, while the change was merged on 23 Jun. I am happy to wait for the next release. Thanks for your help!