pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
21.16k stars 3.64k forks source link

to_hetero() doesn't create input layer for node types containing a space in their name. #3878

Closed mughetto closed 2 years ago

mughetto commented 2 years ago

🐛 Describe the bug

Hi,

While trying to build my own HeteroData dataset from Hetionet, I noticed that the model created by the `torch_geometric.nn.to_hetero` transformation does not create input layers for node types whose names contain a space. For example, the following minimal code reproduces the issue:

import torch
from torch_geometric.data import HeteroData

import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero

# Build a tiny heterogeneous graph with three node types. "Blue Hammer"
# deliberately contains a space in its name to reproduce the to_hetero() issue.
data = HeteroData()
data["RedHammer"].num_nodes = 5
data["RedHammer"].x = torch.zeros(5, 7)
data["GreenHammer"].num_nodes = 3
data["GreenHammer"].x = torch.zeros(3, 7)
data["Blue Hammer"].num_nodes = 11
# The number of feature rows must equal num_nodes (the first draft used 3 rows
# here). The feature width (11) may differ from the other node types because
# the model below builds its SAGEConv layers with (-1, -1) input sizes.
data["Blue Hammer"].x = torch.zeros(11, 11)

# Connectivity shared by both relations: row 0 holds source ("RedHammer")
# indices and row 1 holds destination indices. Destinations are kept below 3
# so they are valid for "GreenHammer" (3 nodes) as well as "Blue Hammer".
edge_type_index = torch.tensor([
    [1, 2, 3],
    [2, 1, 0],
])

data["RedHammer", "isSoldWith", "GreenHammer"].edge_index = edge_type_index
data["RedHammer", "isSoldWith", "Blue Hammer"].edge_index = edge_type_index

# Add the reverse "rev_isSoldWith" edge types, so that "RedHammer" also
# appears as a destination node type.
data = T.ToUndirected()(data)

class GNN(torch.nn.Module):
    """Two-layer GraphSAGE model written for a homogeneous graph.

    ``to_hetero`` later converts it into a heterogeneous model. Both layers
    are created with ``(-1, -1)`` input sizes, so the per-edge-type copies do
    not need their input dimensions spelled out here.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        # conv1 -> ReLU -> conv2, with the same connectivity for both layers.
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)

# Convert the homogeneous GNN into a heterogeneous model. Every edge type in
# the metadata gets its own copy of each layer (one ModuleDict entry per edge
# type), and messages that arrive at the same node type are summed.
model = to_hetero(
    GNN(hidden_channels=2, out_channels=2),
    data.metadata(),
    aggr='sum',
)

print(model)

results in:

GraphModule(
  (conv1): ModuleDict(
    (RedHammer__isSoldWith__GreenHammer): SAGEConv((-1, -1), 2)
    (RedHammer__isSoldWith__Blue Hammer): SAGEConv((-1, -1), 2)
    (GreenHammer__rev_isSoldWith__RedHammer): SAGEConv((-1, -1), 2)
    (Blue Hammer__rev_isSoldWith__RedHammer): SAGEConv((-1, -1), 2)
  )
  (conv2): ModuleDict(
    (RedHammer__isSoldWith__GreenHammer): SAGEConv((-1, -1), 2)
    (RedHammer__isSoldWith__Blue Hammer): SAGEConv((-1, -1), 2)
    (GreenHammer__rev_isSoldWith__RedHammer): SAGEConv((-1, -1), 2)
    (Blue Hammer__rev_isSoldWith__RedHammer): SAGEConv((-1, -1), 2)
  )
)

def forward(self, x, edge_index):
    """Forward pass generated by ``to_hetero``, as shown by ``print(model)``.

    The statements below are the tool's output, kept unchanged to show the
    bug. No input is ever read for the "Blue Hammer" node type. The edge type
    "Blue Hammer" -> "RedHammer" is therefore fed ``None``, the edge type
    "RedHammer" -> "Blue Hammer" is never invoked, and the returned dict maps
    "Blue Hammer" to ``None``.
    """
    # Only the space-free node types are looked up in the input dicts. There is
    # no x.get('Blue Hammer') and no edge_index lookup for either edge type
    # that involves "Blue Hammer".
    x__RedHammer = x.get('RedHammer')
    x__GreenHammer = x.get('GreenHammer');  x = None
    edge_index__RedHammer__isSoldWith__GreenHammer = edge_index.get(('RedHammer', 'isSoldWith', 'GreenHammer'))
    edge_index__GreenHammer__rev_isSoldWith__RedHammer = edge_index.get(('GreenHammer', 'rev_isSoldWith', 'RedHammer'));  edge_index = None
    conv1__GreenHammer = self.conv1.RedHammer__isSoldWith__GreenHammer((x__RedHammer, x__GreenHammer), edge_index__RedHammer__isSoldWith__GreenHammer)
    conv1__RedHammer1 = self.conv1.GreenHammer__rev_isSoldWith__RedHammer((x__GreenHammer, x__RedHammer), edge_index__GreenHammer__rev_isSoldWith__RedHammer);  x__GreenHammer = None
    # The bug: the source features and edge index for this edge type are None.
    conv1__RedHammer2 = getattr(self.conv1, "Blue Hammer__rev_isSoldWith__RedHammer")((None, x__RedHammer), (None, None));  x__RedHammer = None
    # Messages arriving at "RedHammer" are summed (aggr='sum').
    conv1__RedHammer = torch.add(conv1__RedHammer1, conv1__RedHammer2);  conv1__RedHammer1 = conv1__RedHammer2 = None
    relu__RedHammer = conv1__RedHammer.relu();  conv1__RedHammer = None
    relu__GreenHammer = conv1__GreenHammer.relu();  conv1__GreenHammer = None
    conv2__GreenHammer = self.conv2.RedHammer__isSoldWith__GreenHammer((relu__RedHammer, relu__GreenHammer), edge_index__RedHammer__isSoldWith__GreenHammer);  edge_index__RedHammer__isSoldWith__GreenHammer = None
    conv2__RedHammer1 = self.conv2.GreenHammer__rev_isSoldWith__RedHammer((relu__GreenHammer, relu__RedHammer), edge_index__GreenHammer__rev_isSoldWith__RedHammer);  relu__GreenHammer = edge_index__GreenHammer__rev_isSoldWith__RedHammer = None
    # Same problem in the second layer: "Blue Hammer" inputs are still None.
    conv2__RedHammer2 = getattr(self.conv2, "Blue Hammer__rev_isSoldWith__RedHammer")((None, relu__RedHammer), (None, None));  relu__RedHammer = None
    conv2__RedHammer = torch.add(conv2__RedHammer1, conv2__RedHammer2);  conv2__RedHammer1 = conv2__RedHammer2 = None
    # "Blue Hammer" gets no output, because no convolution targeting it ran.
    return {'RedHammer': conv2__RedHammer, 'GreenHammer': conv2__GreenHammer, 'Blue Hammer': None}

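The takeaway is that there is no `x__Blue Hammer` input in the generated forward pass. Features of the "Blue Hammer" node type are never read, the convolutions for its edge types receive `None` inputs, and its output is `None`.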

If this is expected behaviour, I would have expected some sort of warning at runtime or a note in the documentation. Otherwise, I suppose replacing ' ' with '_' when generating the model could be a solution.

Please let me know if I can be of any help in fixing this.

Environment

rusty1s commented 2 years ago

Thanks for reporting. This should be fixed once https://github.com/pyg-team/pytorch_geometric/pull/3882 lands.

mughetto commented 2 years ago

Thanks @rusty1s !