pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
20.92k stars 3.61k forks source link

Automatic conversion of raw Parameters for heterogeneous graph models #9547

Open sssleder opened 1 month ago

sssleder commented 1 month ago

🚀 The feature, motivation and pitch

When creating a heterogeneous model via torch_geometric.nn.to_hetero_with_bases(), using a raw torch.nn.parameter.Parameter directly as a module attribute is unsupported, even though layers such as torch.nn.Linear — which rely on Parameter internally — are supported. I would like this functionality to be implemented, as it would be quite helpful.

An example of this is as follows:

import torch
import torch_geometric as pyg

from torch import tensor
from torch.nn.parameter import Parameter
from torch_geometric.data import HeteroData
from torch_geometric.nn import to_hetero_with_bases

class HeteroParam(torch.nn.Module):
    """Minimal homogeneous model that mixes a PyG lazy ``Linear`` layer with a
    raw ``torch.nn.Parameter`` used as an element-wise skip-connection weight.

    Reproduces the ``to_hetero_with_bases`` failure from this issue: the
    ``get_attr`` FX node produced for ``self.skip_weights`` is not handled by
    the transformer.
    """

    def __init__(self):
        # FIX: the original snippet subclassed ``nn.Module``, but ``nn`` is
        # never imported above — use the fully qualified ``torch.nn.Module``.
        super().__init__()

        # Lazily-initialized linear layer; ``-1`` infers in_channels on the
        # first forward pass.
        self.linear = pyg.nn.Linear(-1, 64)

        # create skip connection weight matrix
        self.skip_weights = torch.nn.Parameter(torch.ones(64,))

    def forward(self, x, edge_index, batch):
        x = self.linear(x)

        # Element-wise scaling of node features by the learned skip weights;
        # this multiplication is the ``call_function mul`` node that trips up
        # the hetero transformer.
        x = x * self.skip_weights

        return x

# Instantiate the homogeneous model before converting it.
model = HeteroParam()

# Minimal heterogeneous fixture: two node types ('_0', '_1') with 2-dim 'x'
# features and all four directed edge types between them, built from a plain
# dict so the example is fully self-contained.
data = HeteroData().from_dict(
        {'_global_store': {},
         '_1': {'x': tensor([[1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.]])},
         '_0': {'x': tensor([[1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 1.],
                [1., 0.],
                [1., 0.],
                [1., 0.],
                [1., 0.]])},
         ('_1', '_', '_0'): {'edge_index': tensor([[ 0,  0,  1,  1,  2,  2,  2,  5,  5,  6,  7,  7,  8,  9, 10, 12],
                [ 0,  1,  0,  5,  0,  4,  1,  6,  7,  2,  5,  3,  3,  2,  1,  7]])},
         ('_1', '_', '_1'): {'edge_index': tensor([[ 0,  0,  0,  1,  1,  3,  3,  4,  5,  6,  7, 11],
                [ 3,  5,  1,  6,  0,  0,  4,  3,  0,  1, 11,  7]])},
         ('_0', '_', '_1'): {'edge_index': tensor([[ 0,  0,  0,  1,  1,  1,  2,  2,  3,  3,  4,  5,  5,  6,  7,  7],
                [ 0,  2,  1,  0,  2, 10,  6,  9,  8,  7,  2,  1,  7,  5,  5, 12]])},
         ('_0', '_', '_0'): {'edge_index': tensor([[0, 2, 3, 3],
                [3, 3, 2, 0]])}})

# Convert to a heterogeneous model via basis decomposition; this call raises
# KeyError('skip_weights') because the raw Parameter attribute is not tracked
# by the FX transformer (see traceback below).
model = to_hetero_with_bases(model, data.metadata(), num_bases=2, in_channels={'x': 2}, debug=True)

With the current implementation this results in the following error:

opcode         name          target                   args                    kwargs
-------------  ------------  -----------------------  ----------------------  --------
placeholder    x             x                        ()                      {}
placeholder    edge_index    edge_index               ()                      {}
placeholder    batch         batch                    ()                      {}
call_module    linear        linear                   (x,)                    {}
get_attr       skip_weights  skip_weights             ()                      {}
call_function  mul           <built-in function mul>  (linear, skip_weights)  {}
output         output        output                   (mul,)                  {}

def forward(self, x, edge_index, batch):
    linear = self.linear(x);  x = None
    skip_weights = self.skip_weights
    mul = linear * skip_weights;  linear = skip_weights = None
    return mul

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[3], line 60
     26 model = HeteroParam()
     28 data = HeteroData().from_dict(
     29         {'_global_store': {},
     30          '_1': {'x': tensor([[1., 0.],
   (...)
     57          ('_0', '_', '_0'): {'edge_index': tensor([[0, 2, 3, 3],
     58                 [3, 3, 2, 0]])}})
---> 60 model = to_hetero_with_bases(model, data.metadata(), num_bases=2, in_channels={'x': 2}, debug=True)

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/to_hetero_with_bases_transformer.py:134, in to_hetero_with_bases(module, metadata, num_bases, in_channels, input_map, debug)
     25 r"""Converts a homogeneous GNN model into its heterogeneous equivalent
     26 via the basis-decomposition technique introduced in the
     27 `"Modeling Relational Data with Graph Convolutional Networks"
   (...)
    130         transformation in debug mode. (default: :obj:`False`)
    131 """
    132 transformer = ToHeteroWithBasesTransformer(module, metadata, num_bases,
    133                                            in_channels, input_map, debug)
--> 134 return transformer.transform()

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/to_hetero_with_bases_transformer.py:183, in ToHeteroWithBasesTransformer.transform(self)
    181 self._edge_offset_dict_initialized = False
    182 self._edge_type_initialized = False
--> 183 out = super().transform()
    184 del self._node_offset_dict_initialized
    185 del self._edge_offset_dict_initialized

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/fx.py:149, in Transformer.transform(self)
    147     self._state[node.name] = 'graph'
    148 elif node.op in ['call_module', 'call_method', 'call_function']:
--> 149     if self.has_edge_level_arg(node):
    150         self._state[node.name] = 'edge'
    151     elif self.has_node_level_arg(node):

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/fx.py:241, in Transformer.has_edge_level_arg(self, node)
    240 def has_edge_level_arg(self, node: Node) -> bool:
--> 241     return self._has_level_arg(node, name='edge')

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/fx.py:225, in Transformer._has_level_arg(self, node, name)
    222     else:
    223         return False
--> 225 return (any([_recurse(value) for value in node.args])
    226         or any([_recurse(value) for value in node.kwargs.values()]))

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/fx.py:225, in <listcomp>(.0)
    222     else:
    223         return False
--> 225 return (any([_recurse(value) for value in node.args])
    226         or any([_recurse(value) for value in node.kwargs.values()]))

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/fx.py:217, in Transformer._has_level_arg.<locals>._recurse(value)
    215 def _recurse(value: Any) -> bool:
    216     if isinstance(value, Node):
--> 217         return getattr(self, f'is_{name}_level')(value)
    218     elif isinstance(value, dict):
    219         return any([_recurse(v) for v in value.values()])

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/fx.py:232, in Transformer.is_edge_level(self, node)
    231 def is_edge_level(self, node: Node) -> bool:
--> 232     return self._is_level(node, name='edge')

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/fx.py:212, in Transformer._is_level(self, node, name)
    211 def _is_level(self, node: Node, name: str) -> bool:
--> 212     return self._state[node.name] == name

KeyError: 'skip_weights'

Alternatives

Simply using torch.nn.Linear(bias=False) is not a viable alternative to torch.nn.parameter.Parameter in all use cases.

Additional context

No response

sssleder commented 1 month ago

Ah it seems I misunderstood the purpose of the input_map argument. By setting input_map={'skip_weights': 'node'} I was able to properly set the Parameters as a node-level argument. The purpose and use of input_map was not immediately clear from the documentation.

sssleder commented 1 month ago

However, using torch_geometric.nn.to_hetero() as follows results in a NotImplementedError:

model = to_hetero(model, data.metadata(), debug=True, input_map={'skip_weights': 'node', 'linear': 'node'})
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[17], line 64
     60 #model = to_hetero_with_bases(model, data.metadata(), num_bases=2, in_channels={'x': 2}, debug=True, input_map={'skip_weights': 'node'})#, 'linear': 'node', 'x': 'node'})
     62 print(data.metadata())
---> 64 model = to_hetero(model, data.metadata(), debug=True, input_map={'skip_weights': 'node'})

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/to_hetero_transformer.py:120, in to_hetero(module, metadata, aggr, input_map, debug)
     30 r"""Converts a homogeneous GNN model into its heterogeneous equivalent in
     31 which node representations are learned for each node type in
     32 :obj:`metadata[0]`, and messages are exchanged between each edge type in
   (...)
    117         transformation in debug mode. (default: :obj:`False`)
    118 """
    119 transformer = ToHeteroTransformer(module, metadata, aggr, input_map, debug)
--> 120 return transformer.transform()

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/fx.py:165, in Transformer.transform(self)
    163     elif is_global_pooling_op(self.module, op, node.target):
    164         op = 'call_global_pooling_module'
--> 165     getattr(self, op)(node, node.target, node.name)
    167 # Remove all unused nodes in the computation graph, i.e., all nodes
    168 # which have been replaced by node type-wise or edge type-wise variants
    169 # but which are still present in the computation graph.
    170 # We do this by iterating over the computation graph in reversed order,
    171 # and try to remove every node. This does only succeed in case there
    172 # are no users of that node left in the computation graph.
    173 for node in reversed(list(self.graph.nodes)):

File ~/.conda/envs/graph_mining/lib/python3.11/site-packages/torch_geometric/nn/to_hetero_transformer.py:191, in ToHeteroTransformer.get_attr(self, node, target, name)
    190 def get_attr(self, node: Node, target: Any, name: str):
--> 191     raise NotImplementedError

NotImplementedError: