pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
21.14k stars 3.63k forks source link

Create GraphAutoEncoder for Heterogeneous Graph #5334

Open othmanelhoufi opened 2 years ago

othmanelhoufi commented 2 years ago

🐛 Describe the bug

After several failed attempts to create a heterogeneous graph autoencoder, it's time to ask for help.

Here is a sample of my Dataset:

====================
Number of graphs: 560
Number of features: {'article': 769, 'user': 772}

HeteroData(
  y=[1],
  article={ x=[1, 769] },
  user={ x=[67, 772] },
  (user, tweeted, article)={ edge_index=[2, 54] },
  (user, retweeted, user)={ edge_index=[2, 4] },
  (user, liked, user)={ edge_index=[2, 11] }
)
=============================================================
Number of nodes: 68
Number of edges: 69
Average node degree: 1.01
Has isolated nodes: True
Has self-loops: False
Is undirected: False

I tried to follow these two tutorials in the PyTorch Geometric documentation (the links were not preserved in this copy):

And here is what I wrote:

from dataset import FakeNewsDataset
from torch.utils.data import random_split
import torch.nn as nn
from torch_geometric.nn import GCNConv, GAE, GATConv, Linear, to_hetero
from torch_geometric.loader import DataLoader
import torch 
from tqdm import tqdm

class GCNEncoder(torch.nn.Module):
    """Two-layer GCN encoder that maps node features to node embeddings.

    Layer stack: GCNConv -> ReLU -> Dropout -> GCNConv. Both convolutions are
    constructed with ``add_self_loops=False``.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_size: Width of the intermediate representation.
        out_channels: Size of each output embedding.
        dropout: Dropout probability applied between the two convolutions.

    NOTE(review): GCNConv is not listed as supporting bipartite message passing
    (source node type != destination node type), which ``to_hetero`` needs for
    edge types such as ``(user, tweeted, article)``. PyG's heterogeneous guide
    uses layers such as ``SAGEConv((-1, -1), ...)`` or
    ``GATConv((-1, -1), ..., add_self_loops=False)`` instead. Confirm before
    converting this encoder with ``to_hetero``.
    """

    def __init__(self, in_channels, hidden_size, out_channels, dropout):
        super().__init__()
        # Submodule names and registration order are kept so that
        # state_dict keys stay the same.
        self.conv1 = GCNConv(in_channels, hidden_size, add_self_loops=False)
        self.conv2 = GCNConv(hidden_size, out_channels, add_self_loops=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph given by ``edge_index``."""
        hidden = self.conv1(x, edge_index).relu()
        hidden = self.dropout(hidden)
        embeddings = self.conv2(hidden, edge_index)
        return embeddings

def train(model, optimizer, train_loader, training_set):
    """Run one training epoch and return the average per-graph loss.

    Args:
        model: Model exposing ``encode(x_dict, edge_index_dict)`` and
            ``recon_loss(z, edge_index)`` (a GAE-style interface).
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Iterable of mini-batches that expose ``x_dict``,
            ``edge_index_dict`` and ``num_graphs``.
        training_set: Dataset that ``train_loader`` draws from. Its length is
            the number of graphs used to normalise the accumulated loss.

    Returns:
        float: Sum over batches of (batch loss * graphs in batch), divided by
        ``len(training_set)``.

    NOTE(review): GAE's ``recon_loss`` indexes a single embedding tensor with a
    single ``edge_index``. Passing dicts (as done here) is unlikely to work for
    heterogeneous graphs, so the loss probably has to be computed per edge type
    by hand (see the maintainer reply in this thread).
    """
    # Enable dropout; the model may have been left in eval mode elsewhere.
    model.train()
    # Batches come off the loader on CPU; move each one to wherever the model lives.
    device = next(model.parameters()).device
    total_loss = 0.0
    for data in tqdm(train_loader, leave=False):
        data = data.to(device)
        optimizer.zero_grad()
        z = model.encode(data.x_dict, data.edge_index_dict)
        loss = model.recon_loss(z, data.edge_index_dict)
        loss.backward()
        optimizer.step()
        # `loss` is a per-batch mean. Weight it by the number of graphs in the
        # batch so that dividing by len(training_set) gives a per-graph average.
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(training_set)

def main():
    """Train a graph autoencoder on the 'politifact' FakeNewsDataset.

    Splits the dataset 60/20/20 into training/validation/test subsets, builds
    the model, and runs ``num_epochs`` epochs of ``train``, printing the loss
    after each epoch.
    """
    # `device`, `optimizer` and `train_loader` were used below without being
    # defined; they are created explicitly here.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = FakeNewsDataset(name='politifact', encoder='all-MiniLM-L6-v2')

    num_training = int(len(dataset) * 0.6)
    num_val = int(len(dataset) * 0.2)
    num_test = len(dataset) - (num_training + num_val)
    training_set, validation_set, test_set = random_split(
        dataset, [num_training, num_val, num_test])

    num_epochs = 4
    batch_size = 10
    in_channels, out_channels = dataset.num_features['article'], 128
    hidden_dim = 50
    dropout = 0.2
    # NOTE(review): no learning rate was given in the original; 0.01 is a placeholder.
    learning_rate = 0.01

    train_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)

    # NOTE: GAE does not support heterogeneous graphs. Passing the GAE wrapper
    # to `to_hetero` is likely what raises
    # `NotImplementedError: Module [GAE] is missing the required "forward" function`.
    # The suggested workaround is to convert only the encoder with `to_hetero`
    # and compute the reconstruction loss manually.
    model = GAE(GCNEncoder(in_channels, hidden_dim, out_channels, dropout))
    model = to_hetero(model, dataset[0].metadata(), aggr='sum')
    model = model.to(device)

    # Build the optimizer after conversion and the device move so that it
    # tracks the final parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in tqdm(range(num_epochs)):
        loss = train(model, optimizer, train_loader, training_set)
        print("Loss is : ", loss)


if __name__ == "__main__":
    main()

FYI, the error I get is: `NotImplementedError: Module [GAE] is missing the required "forward" function`

However, when I run the GAE example provided by PyTorch Geometric on GitHub, it works fine, so I suspect that GAE does not work with my heterogeneous graph.

Environment

rusty1s commented 2 years ago

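The current GAE module does not support heterogeneous graphs. Instead, convert only the encoder with `to_hetero` and compute the link-reconstruction loss yourself, along these lines: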

import torch.nn.functional as F
from torch_geometric.utils import negative_sampling

# Only the encoder is converted; GAE itself is not used.
# NOTE(review): GCNConv may not support bipartite message passing under
# to_hetero. PyG's hetero guide uses e.g. SAGEConv((-1, -1), ...). Confirm.
model = GCNEncoder(...)
# to_hetero needs the graph's (node_types, edge_types) metadata.
model = to_hetero(model, dataset[0].metadata())

# Placeholder for the (src, relation, dst) edge type whose links are
# reconstructed, e.g. ('user', 'tweeted', 'article').
edge_type = ('src_node_type', 'edge_type_to_predict', 'dst_node_type')
src_type, _, dst_type = edge_type

def train(data):
    # The converted module's forward() returns a dict of per-node-type
    # embeddings. It has no `encode` method.
    z_dict = model(data.x_dict, data.edge_index_dict)

    pos_edge_label_index = data[edge_type].edge_index
    pos_edge_label = torch.ones(pos_edge_label_index.size(1))

    # Sample src->dst pairs that are not edges. Bipartite node counts are
    # passed as a (num_src, num_dst) tuple.
    neg_edge_label_index = negative_sampling(
        pos_edge_label_index,
        num_nodes=(data[src_type].num_nodes, data[dst_type].num_nodes),
    )
    # Negative (non-existent) edges must be labelled 0, not 1.
    neg_edge_label = torch.zeros(neg_edge_label_index.size(1))

    edge_label_index = torch.cat([pos_edge_label_index, neg_edge_label_index], dim=1)
    edge_label = torch.cat([pos_edge_label, neg_edge_label], dim=0)

    z_src = z_dict[src_type][edge_label_index[0]]
    z_dst = z_dict[dst_type][edge_label_index[1]]

    # Inner-product decoder: one logit per candidate edge.
    recon = (z_src * z_dst).sum(dim=-1)
    loss = F.binary_cross_entropy_with_logits(recon, edge_label)