I think there is an issue with the way the Heterogeneous Graph Transformer (HGT) operator (HGTConv) is currently implemented, or at least with the example of its usage.
Basically, the `__init__` method requires a `data` object, which is currently not passed in as an argument:
```python
import os.path as osp

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HGTConv, Linear

path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/DBLP')
# We initialize conference node features with a single one-vector as feature:
dataset = DBLP(path, transform=T.Constant(node_types='conference'))
data = dataset[0]
print(data)


class HGT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # `data` is not a constructor argument: it is read from the
        # module-level variable defined above.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                           num_heads)
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        # Project every node type to `hidden_channels`.
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.lin(x_dict['author'])
```
This results in an error on the first line of the following block, when the model object is instantiated:
```python
model = HGT(hidden_channels=64, out_channels=4, num_heads=2, num_layers=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)
```
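A minimal, library-free sketch of the dependency described above, assuming only the Python standard library. `HGTLike`, `HGTExplicit`, `try_build`, and the `metadata` tuple are hypothetical stand-ins, not PyG APIs; `metadata` only mimics the `(node_types, edge_types)` shape returned by `data.metadata()`. Because `data` is not a constructor parameter, Python looks it up in the module's global namespace when `__init__` runs, so a constructor written this way fails with a `NameError` whenever no global `data` exists at that moment (for instance, if the class is reused in a module or notebook that skips the dataset-loading lines). Passing the metadata in explicitly removes that hidden dependency.

```python
from types import SimpleNamespace


class HGTLike:
    """Hypothetical stand-in: reads `data` from module globals, like the example."""

    def __init__(self, hidden_channels):
        # `data` is not a parameter, so it is resolved in the module's
        # global namespace at the moment __init__ runs.
        self.node_types = list(data.node_types)
        self.hidden_channels = hidden_channels


class HGTExplicit:
    """Hypothetical stand-in: the graph metadata is passed in explicitly."""

    def __init__(self, hidden_channels, metadata):
        node_types, edge_types = metadata  # same shape as data.metadata()
        self.node_types = list(node_types)
        self.edge_types = list(edge_types)
        self.hidden_channels = hidden_channels


def try_build():
    """Instantiate the global-dependent class and report what happened."""
    try:
        HGTLike(hidden_channels=64)
        return "ok"
    except NameError as exc:
        return f"NameError: {exc}"


# Illustrative placeholder metadata (hypothetical, not loaded from DBLP).
metadata = (
    ["author", "paper", "term", "conference"],
    [("author", "to", "paper"), ("paper", "to", "author")],
)

before = try_build()  # no global `data` yet, so this reports a NameError
data = SimpleNamespace(node_types=metadata[0])  # stand-in for dataset[0]
after = try_build()   # now the global lookup succeeds

# The explicit variant works regardless of what globals exist.
explicit = HGTExplicit(hidden_channels=64, metadata=metadata)
print(before)
print(after)
print(explicit.node_types)
```

In the real model, the explicit variant would correspond to adding a `metadata` parameter to `HGT.__init__`, iterating over its node types, and forwarding it to `HGTConv` in place of `data.metadata()`; that refactor is only a suggestion, not part of the existing example.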
Thanks!
Versions
Environment info:
name: channels: