pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Issue in implementation of HGTConv #9347

Closed pauvilasoler closed 1 month ago

pauvilasoler commented 1 month ago

🐛 Describe the bug

Hi,

I think there is an issue with the way the Heterogeneous Graph Transformer (HGT) operator (HGTConv) is currently implemented, or at least with the example of its usage.

Basically, the __init__ method of the example model relies on a data object that is currently not passed to it as an argument:

import os.path as osp

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HGTConv, Linear

path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/DBLP')
# We initialize conference node features with a single one-vector as feature:
dataset = DBLP(path, transform=T.Constant(node_types='conference'))
data = dataset[0]
print(data)

class HGT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                           num_heads)
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        return self.lin(x_dict['author'])

This results in an error on the first line of the following block, when the model is instantiated:

model = HGT(hidden_channels=64, out_channels=4, num_heads=2, num_layers=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

Thanks!

Versions

Environment info:

name: channels: