pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Bug: `GCNConv` in `to_hetero` with single node type #4271

Closed fratajcz closed 2 years ago

fratajcz commented 2 years ago

šŸ› Describe the bug

I just started using heterogeneous graphs; so far I have roughly one year of experience with homogeneous graphs using your library. I have a graph with just one node type ("gene") but multiple edge types between the nodes. My HeteroData object prints as follows:

HeteroData(
  gene={
    x=[16433, 93],
    y=[16433],
    train_mask=[16433],
    test_mask=[16433],
    val_mask=False
  },
  (gene, BioPlex30HCT116, gene)={ edge_index=[2, 95370] },
  (gene, BioPlex30293T, gene)={ edge_index=[2, 156370] },
  (gene, HuRI, gene)={ edge_index=[2, 74488] }
)

I convert my GCN-based model (that runs fine with a homogeneous Data object) using the to_hetero() function.

I then pass my data into the model as follows:

out = model(data.x_dict, data.edge_index_dict)

where data is the HeteroData Object described above.

However, I get an error and I think it is because it expects edge_weight somewhere:

Traceback (most recent call last):
    out = model(data.x_dict, data.edge_index_dict)
  File "/home/fratajcz/anaconda3/envs/compat/lib/python3.7/site-packages/torch/fx/graph_module.py", line 308, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/home/fratajcz/anaconda3/envs/compat/lib/python3.7/site-packages/torch/fx/graph_module.py", line 308, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/home/fratajcz/anaconda3/envs/compat/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "<eval_with_key_1>", line 7, in forward
    edge_weight__gene__BioPlex30HCT116__gene = edge_weight.get(('gene', 'BioPlex30HCT116', 'gene'))
AttributeError: 'NoneType' object has no attribute 'get'

Do I have to specify edge weights in heterogeneous graphs? The documentation doesn't mention it. It is also odd that the traceback only mentions torch packages and not torch_geometric.

Cheers and thanks, Florin


rusty1s commented 2 years ago

In general, edge_weight shouldn't be required, so there might be something weird going on here. What does your homogeneous model look like?

fratajcz commented 2 years ago

Thanks, I just found the problem. The forward pass of my model had an if clause that checked if edge_weight is None:

# apply graph convolutions
for i in range(len(self.gcnconv)):
    x = self.norm[i](x)
    if edge_weight is None:
        x = self.gcnconv[i](x, edge_index)
    else:
        x = self.gcnconv[i](x, edge_index, edge_weight)
    x = F.elu(x)

which obviously tripped to_hetero into expecting an edge_weight argument, even though it was always None. When I print this model after calling to_hetero, the forward pass starts with:

def forward(self, x, edge_index, edge_weight = None):
    x__gene = x.get('gene');  x = None
    edge_index__gene__BioPlex30HCT116__gene = edge_index.get(('gene', 'BioPlex30HCT116', 'gene'))
    edge_index__gene__BioPlex30293T__gene = edge_index.get(('gene', 'BioPlex30293T', 'gene'))
    edge_index__gene__HuRI__gene = edge_index.get(('gene', 'HuRI', 'gene'));  edge_index = None
    edge_weight__gene__BioPlex30HCT116__gene = edge_weight.get(('gene', 'BioPlex30HCT116', 'gene'))
    edge_weight__gene__BioPlex30293T__gene = edge_weight.get(('gene', 'BioPlex30293T', 'gene'))
    edge_weight__gene__HuRI__gene = edge_weight.get(('gene', 'HuRI', 'gene'));  edge_weight = None

However, if I change the forward pass of my model to

# apply graph convolutions
for i in range(len(self.gcnconv)):
    x = self.norm[i](x)
    x = self.gcnconv[i](x, edge_index)
    x = F.elu(x)

it works and the three lines regarding the edge_weight are gone.

However, now I run into the following error:

File "/home/ifratajcz/anaconda3/envs/compat/lib/python3.7/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 163, in forward
    edge_index, edge_weight, x.size(self.node_dim),
AttributeError: 'tuple' object has no attribute 'size'

The problem is that I use GCN layers in a heterogeneous setting. I thought it would work since I have only one node type, but apparently it doesn't.

rusty1s commented 2 years ago

This is definitely a bug. I will try to fix it.

rusty1s commented 2 years ago

I just fixed this in #4279 :)

fratajcz commented 2 years ago

Thanks :)

aayyad89 commented 2 years ago

I am still getting the same bug ("AttributeError: 'tuple' object has no attribute 'size'") with GCNConv:

import torch
from torch_geometric.nn import GCNConv, HeteroConv

class LightGCN(torch.nn.Module):
    def __init__(self,
                 num_users,
                 num_movies, n_layers=2, 
                 embedding_dim=20):
        super().__init__() 

        self.num_users, self.num_movies = num_users, num_movies
        self.emb_dim = embedding_dim
        self.n_layers = n_layers
        self.user_emb = torch.nn.Embedding(num_users, embedding_dim, max_norm=1.0)
        self.movie_emb = torch.nn.Embedding(num_movies, embedding_dim, max_norm=1.0)

        graph_layer = GCNConv(embedding_dim, embedding_dim, bias=False,
                              add_self_loops=False, normalize=True)

        self.conv = HeteroConv({('user', 'rates', 'movie'): graph_layer,
                                ('movie', 'rev_rates', 'user'): graph_layer})

    def forward(self, data):

        node_id_dict = data.node_id_dict
        edge_index_dict = data.edge_index_dict

        user_emb = self.user_emb(node_id_dict['user'])
        movie_emb = self.movie_emb(node_id_dict['movie'])
        emb_dict_init = {'user': user_emb, 'movie': movie_emb}

        emb_dict = emb_dict_init
        embs = []

        for i in range(self.n_layers):
            emb_dict = self.conv(emb_dict, edge_index_dict)
            embs.append(emb_dict)

        emb_final_user = torch.stack([emb['user'] for emb in embs], -1).mean(dim=-1)
        emb_final_movie = torch.stack([emb['movie'] for emb in embs], -1).mean(dim=-1)

        return emb_final_user, emb_final_movie, emb_dict_init

This doesn't happen when I switch to GATConv.

rusty1s commented 2 years ago

Yes, you can only use GCNConv for passing messages to the same node type, e.g.:

HeteroConv({
    ...
    ('movie', 'is_similar', 'movie'): GCNConv(...),
    ...
})

This is a limitation of GCNConv as it does not support bipartite message passing.

aayyad89 commented 2 years ago

Thanks, I hadn't realized that.

kajocina commented 10 months ago

@rusty1s is this just a matter of implementation, or can GCNConv theoretically not be applied in such scenarios?

rusty1s commented 10 months ago

I would say it is just a limitation of the operator. In particular since