pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Lazy Initialized SAGEConv layer cannot be used with LSTM aggregator #5928

Open John-Atha opened 1 year ago

John-Atha commented 1 year ago

I am working on the heterogeneous link prediction example from the official PyG GitHub repository, experimenting with various configurations of the GNNEncoder built mainly on the SAGEConv layer. As described in the Heterogeneous Graph Learning docs, I am using the to_hetero function together with the lazy initialization feature (in_channels=-1) of the SAGEConv layer. I am trying to use the LSTM aggregator on the SAGEConv layers (aggr="lstm"):

import torch
from torch_geometric.nn import Linear, SAGEConv, to_hetero
# `data` and `device` are defined earlier in the linked example

class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(-1, hidden_channels, aggr="lstm")
        self.conv2 = SAGEConv(-1, out_channels, aggr="lstm")

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)
        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)

class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

model = Model(hidden_channels=32).to(device)

but I keep getting the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In [6], line 1
----> 1 model = Model(hidden_channels=32).to(device)
      3 # Due to lazy initialization, we need to run one model step so the number
      4 # of parameters can be inferred:
      5 with torch.no_grad():

Cell In [3], line 28, in Model.__init__(self, hidden_channels)
     26 def __init__(self, hidden_channels):
     27     super().__init__()
---> 28     self.encoder = GNNEncoder(hidden_channels, hidden_channels)
     29     self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
     30     self.decoder = EdgeDecoder(hidden_channels)

Cell In [3], line 4, in GNNEncoder.__init__(self, hidden_channels, out_channels)
      2 def __init__(self, hidden_channels, out_channels):
      3     super().__init__()
----> 4     self.conv1 = SAGEConv(-1, hidden_channels, aggr="lstm")
      5     self.conv2 = SAGEConv(-1, out_channels, aggr="lstm")

File ~/diploma/environ/lib/python3.9/site-packages/torch_geometric/nn/conv/sage_conv.py:92, in SAGEConv.__init__(self, in_channels, out_channels, aggr, normalize, root_weight, project, bias, **kwargs)
     89     kwargs['aggr_kwargs'].setdefault('in_channels', in_channels[0])
     90     kwargs['aggr_kwargs'].setdefault('out_channels', in_channels[0])
---> 92 super().__init__(aggr, **kwargs)
     94 if self.project:
     95     self.lin = Linear(in_channels[0], in_channels[0], bias=True)

File ~/diploma/environ/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py:125, in MessagePassing.__init__(self, aggr, aggr_kwargs, flow, node_dim, decomposed_layers, **kwargs)
    123 elif isinstance(aggr, (str, Aggregation)):
    124     self.aggr = str(aggr)
--> 125     self.aggr_module = aggr_resolver(aggr, **(aggr_kwargs or {}))
    126 elif isinstance(aggr, (tuple, list)):
    127     self.aggr = [str(x) for x in aggr]

File ~/diploma/environ/lib/python3.9/site-packages/torch_geometric/nn/resolver.py:93, in aggregation_resolver(query, *args, **kwargs)
     86 aggrs = [
     87     aggr for aggr in vars(aggr).values()
     88     if isinstance(aggr, type) and issubclass(aggr, base_cls)
     89 ]
     90 aggr_dict = {
     91     'add': aggr.SumAggregation,
     92 }
---> 93 return resolver(aggrs, aggr_dict, query, base_cls, *args, **kwargs)

File ~/diploma/environ/lib/python3.9/site-packages/torch_geometric/nn/resolver.py:33, in resolver(classes, class_dict, query, base_cls, *args, **kwargs)
     31 if query_repr in [cls_repr, cls_repr.replace(base_cls_repr, '')]:
     32     if inspect.isclass(cls):
---> 33         obj = cls(*args, **kwargs)
     34         assert callable(obj)
     35         return obj

File ~/diploma/environ/lib/python3.9/site-packages/torch_geometric/nn/aggr/lstm.py:26, in LSTMAggregation.__init__(self, in_channels, out_channels, **kwargs)
     24 self.in_channels = in_channels
     25 self.out_channels = out_channels
---> 26 self.lstm = LSTM(in_channels, out_channels, batch_first=True, **kwargs)
     27 self.reset_parameters()

File ~/diploma/environ/lib/python3.9/site-packages/torch/nn/modules/rnn.py:675, in LSTM.__init__(self, *args, **kwargs)
    673 print([i for i in kwargs])
    674 print("lalalal")
--> 675 super(LSTM, self).__init__('LSTM', *args, **kwargs)

File ~/diploma/environ/lib/python3.9/site-packages/torch/nn/modules/rnn.py:69, in RNNBase.__init__(self, mode, input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, proj_size, device, dtype)
     67     raise ValueError("proj_size should be a positive integer or zero to disable projections")
     68 if proj_size >= hidden_size:
---> 69     raise ValueError("proj_size has to be smaller than hidden_size")
     71 if mode == 'LSTM':
     72     gate_size = 4 * hidden_size

ValueError: proj_size has to be smaller than hidden_size

As far as I can understand, the error is caused because hidden_size is equal to -1 due to the lazy initialization, while proj_size is equal to 0 by default. Given that proj_size must be non-negative and strictly smaller than hidden_size, the check necessarily fails. Is there a way to use the lazy initialization feature of the SAGEConv layer together with the LSTM aggregator?
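The failing check is easy to reproduce with plain PyTorch (a minimal sketch, independent of PyG, to confirm the diagnosis):

import torch

# SAGEConv forwards its lazy in_channels (-1) to LSTMAggregation, which
# passes it to torch.nn.LSTM as hidden_size; proj_size defaults to 0, so
# the guard `proj_size >= hidden_size` (0 >= -1) fires immediately:
torch.nn.LSTM(input_size=-1, hidden_size=-1, batch_first=True)
# ValueError: proj_size has to be smaller than hidden_size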


rusty1s commented 1 year ago

Interesting. I guess for now you can resolve it by writing your model as:

class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # A lazy Linear layer absorbs the unknown input dimension, so the
        # SAGEConv layers (and their LSTM aggregators) see fixed sizes:
        self.lin = Linear(-1, hidden_channels)
        self.conv1 = SAGEConv(hidden_channels, hidden_channels, aggr="lstm")
        self.conv2 = SAGEConv(hidden_channels, out_channels, aggr="lstm")

    def forward(self, x, edge_index):
        x = self.lin(x).relu()
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

Will also look into supporting -1 for LSTM-style aggregation.
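With this workaround only the lazy Linear layer is left uninitialized, so the usual dummy forward pass from the example still materializes every parameter (a sketch assuming the train_data split defined in the linked example):

# Due to lazy initialization, run one model step so the number of
# parameters can be inferred:
with torch.no_grad():
    model.encoder(train_data.x_dict, train_data.edge_index_dict)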