pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Issue in src.new_zeros(size).scatter_add_(dim, index, src) for Heterogeneous Model #9391

Open pauvilasoler opened 4 weeks ago

pauvilasoler commented 4 weeks ago

🐛 Describe the bug

Hi,

I am running into an issue when trying to train a heterogeneous GNN on a custom dataset.

Context:

My dataset is a list of HeteroData objects (i.e., heterogeneous graphs), each with 1 node of type 'Ego' and 25 nodes of type 'Alter'. There are two edge types: ('Alter', 'to', 'Ego'), with 25 edges per graph (each of the 25 'Alter' nodes is connected to the 'Ego'), and ('Alter', 'to', 'Alter'), whose number of edges varies from graph to graph. The first edge type has edge attributes; the second does not.

More specifically, this is what the data looks like:

[Image: example HeteroData object from the dataset]

Regarding the model, I am using the Heterogeneous Convolution Wrapper (HeteroConv) that you can see below:

```python
import torch
import torch.nn as nn
from torch.nn import Linear
from torch.optim import Adam
from torch_geometric.nn.conv import HeteroConv, GATConv, GraphConv


class Model(nn.Module):
    def __init__(self, n_conv_layers, hidden_channels, out_channels, loss_fn=None):
        super().__init__()

        self.convs = nn.ModuleList()
        for _ in range(n_conv_layers):
            # Fresh conv modules per layer, so layers do not share weights
            # (input sizes are inferred lazily from the first forward pass via -1).
            # Note: GATConv ignores edge attributes unless `edge_dim` is set.
            hetero_conv = HeteroConv({
                ('Alter', 'to', 'Alter'): GraphConv(-1, hidden_channels),
                ('Alter', 'to', 'Ego'): GATConv((-1, -1), hidden_channels,
                                                add_self_loops=False, aggr='mean'),
            }, aggr='sum')
            self.convs.append(hetero_conv)

        self.linear = Linear(hidden_channels, out_channels)

        # Loss function; defaults to MSE (an assumption for a single-output target).
        self.loss_fn = loss_fn if loss_fn is not None else nn.MSELoss()

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)
        self.optimizer = Adam(params=self.parameters())

    def forward(self, x_dict, edge_index_dict, edge_attributes_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict, edge_attributes_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return self.linear(x_dict['Ego'])  # predictions are made on the Ego node

    def train_model(self, train_data):
        self.train()
        losses, preds, ys = [], [], []
        for data in train_data:
            data = data.to(self.device)  # data must live on the same device as the model
            self.optimizer.zero_grad()   # reset gradients for every graph
            out = self(data.x_dict, data.edge_index_dict, data.edge_attributes_dict)
            loss = self.loss_fn(out, data['Ego'].y)  # forward() already returns the Ego output
            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())
            preds.append(out.detach())
            ys.append(data['Ego'].y)
        return losses, preds, ys

    @torch.no_grad()
    def test_model(self, test_data):
        self.eval()  # nn.Module has no test(); eval() switches to inference mode
        losses, preds, ys = [], [], []
        for data in test_data:
            data = data.to(self.device)
            out = self(data.x_dict, data.edge_index_dict, data.edge_attributes_dict)
            losses.append(self.loss_fn(out, data['Ego'].y).item())
            preds.append(out)
            ys.append(data['Ego'].y)
        return losses, preds, ys
```

The error occurs when I train the model:

```python
train_data = graphs  # `graphs` is the dataset: a list of HeteroData objects like the one above

model = Model(n_conv_layers=1, hidden_channels=64, out_channels=1)

model.train_model(train_data)
```

Here is the error message:

[Image: error traceback]

I noticed that a similar issue was raised in https://github.com/pyg-team/pytorch_geometric/issues/4588, but the solutions proposed there do not work in my (heterogeneous) case.

In addition, a second exception is raised that seems to be related:

[Image: traceback of the second exception]

I would appreciate any help or ideas.

Thanks a lot!

Versions

Environment (yaml)

```yaml
name: predicting-GNNs
channels:
dependencies:
```

rusty1s commented 3 weeks ago

What does data.validate() return for you?

pauvilasoler commented 3 weeks ago

data.validate() returns True for every HeteroData object in the dataset.

I suspect the problem may be how the edge indices are encoded: the indices of the 'Alter' nodes run from 1 to 25, whereas they should perhaps run from 0 to 24.

For example, this is what the edge index for the ('Alter', 'to', 'Ego') relation looks like:

[[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25], [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]]
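A small sketch of why 1-based indices break message passing: with 25 'Alter' nodes the valid indices are 0 to 24, so index 25 points past the last node. If the ('Alter', 'to', 'Alter') edges use the same 1-based numbering, aggregating messages onto the 'Alter' nodes scatters into a slot that does not exist, which is the kind of failure reported from `scatter_add_` in the title. The snippet runs on CPU, where the out-of-range index raises immediately; the helper name `scatter_fails` is hypothetical, and the fix assumes the indices are simply off by one.

```python
import torch

NUM_ALTERS = 25

# The ('Alter', 'to', 'Ego') edge index as posted above: Alter indices run from 1 to 25.
edge_index = torch.tensor([list(range(1, NUM_ALTERS + 1)), [0] * NUM_ALTERS])


def scatter_fails(index: torch.Tensor, dim_size: int) -> bool:
    """Return True if summing one message per entry of `index` into `dim_size` slots raises."""
    out = torch.zeros(dim_size)
    try:
        out.scatter_add_(0, index, torch.ones(index.numel()))
    except (IndexError, RuntimeError):
        return True
    return False


# Scattering onto the 25 Alter nodes with 1-based indices touches slot 25, which does not exist.
print(scatter_fails(edge_index[0], NUM_ALTERS))  # True

# Fix, assuming the indices are just shifted by one: make the Alter row 0-based.
fixed = edge_index.clone()
fixed[0] -= 1
print(int(fixed[0].min()), int(fixed[0].max()))  # 0 24
print(scatter_fails(fixed[0], NUM_ALTERS))  # False
```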

pauvilasoler commented 3 weeks ago

That was indeed the issue.

However, shouldn't data.validate() return False in cases like this, where the indices are encoded incorrectly?

Thanks a lot anyway!

rusty1s commented 1 week ago

data.validate() just checks for invalid edges. It cannot automatically detect whether edges are semantically incorrect.
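Since `validate()` cannot catch semantic mistakes, dataset-specific invariants have to be checked separately. A minimal sketch for the ego-network schema in this issue; the function name `ego_network_problems` is hypothetical, and the invariants it checks (the single 'Ego' at index 0, each 'Alter' linked to it exactly once, all 'Alter' indices in 0..num_alters-1) are taken from the dataset description at the top of the issue.

```python
from typing import List

import torch


def ego_network_problems(alter_to_ego: torch.Tensor,
                         alter_to_alter: torch.Tensor,
                         num_alters: int) -> List[str]:
    """Return a list of schema violations (empty if none) for one ego-network graph."""
    problems = []
    src, dst = alter_to_ego  # rows of the 2 x E edge index
    # Each Alter should be linked to the Ego exactly once.
    if sorted(src.tolist()) != list(range(num_alters)):
        problems.append(f"Alter->Ego sources should be exactly 0..{num_alters - 1}")
    # The only Ego node has local index 0.
    if dst.numel() > 0 and bool((dst != 0).any()):
        problems.append("Alter->Ego targets should all be the Ego node (index 0)")
    # Alter->Alter edges must stay within the Alter index range.
    if alter_to_alter.numel() > 0 and (
            int(alter_to_alter.min()) < 0 or int(alter_to_alter.max()) >= num_alters):
        problems.append(f"Alter->Alter indices should lie in 0..{num_alters - 1}")
    return problems


bad = torch.tensor([list(range(1, 26)), [0] * 25])
no_edges = torch.empty(2, 0, dtype=torch.long)
print(ego_network_problems(bad, no_edges, 25))
# ['Alter->Ego sources should be exactly 0..24']
```

In practice this would be run over every graph in the dataset, passing `data['Alter', 'to', 'Ego'].edge_index`, `data['Alter', 'to', 'Alter'].edge_index`, and `data['Alter'].num_nodes`.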