pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

NNConv layer : MPGNN layer ? #6601

Closed · Forbu closed this issue 1 year ago

Forbu commented 1 year ago

🐛 Describe the bug

This issue is not really about a "bug" but is a question about a graph neural network operator: the NNConv operator (https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.NNConv).

There is possibly an issue concerning this layer (but I may have misunderstood something). Basically, in the paper Neural Message Passing for Quantum Chemistry (https://arxiv.org/abs/1704.01212) (NNConv is the layer that reproduces the architecture of the paper), we have two operations:

$$m_v^{t+1} = \sum_{w \in N(v)} M_t(h_v^t, h_w^t, e_{vw}), \qquad h_v^{t+1} = U_t(h_v^t, m_v^{t+1})$$

Basically, we have the message passing, which is a function of the edge attributes and the two incident nodes, and then the processing of the aggregated message together with the node's own information.

But when I look at the code of the NNConv layer, the `message` method does not seem to implement those operations (https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/nn_conv.html#NNConv):

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        weight = self.nn(edge_attr)
        weight = weight.view(-1, self.in_channels_l, self.out_channels)
        return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)
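
For reference, here is a minimal usage sketch of NNConv (the dimensions, hidden size, and edge network architecture below are illustrative, not from this issue): the edge network must map each edge feature vector to `in_channels * out_channels` values.

    import torch
    import torch.nn as nn
    from torch_geometric.nn import NNConv

    in_channels, out_channels, edge_dim = 16, 32, 4

    # Edge network: maps each edge feature vector to a flattened
    # in_channels x out_channels weight matrix.
    edge_net = nn.Sequential(
        nn.Linear(edge_dim, 64),
        nn.ReLU(),
        nn.Linear(64, in_channels * out_channels),
    )

    conv = NNConv(in_channels, out_channels, edge_net, aggr='mean')

    x = torch.randn(10, in_channels)            # 10 nodes
    edge_index = torch.randint(0, 10, (2, 40))  # 40 random edges
    edge_attr = torch.randn(40, edge_dim)
    out = conv(x, edge_index, edge_attr)        # shape: [10, out_channels]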


Forbu commented 1 year ago

I can push a correction with this new code I wrote:

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing

# `MLP` is assumed to be a user-defined helper: a multilayer perceptron
# mapping `in_dim` features to `out_dim` features with `hidden_layers` layers.

class MPGNNConv(MessagePassing):
    def __init__(self, node_dim, edge_dim, layers=3):
        super().__init__(aggr='mean', node_dim=0)
        # Edge/message MLP: operates on [x_i, x_j, edge_attr].
        self.lin_edge = MLP(in_dim=node_dim * 2 + edge_dim, out_dim=node_dim,
                            hidden_layers=layers)
        # Node update MLP: operates on [x, aggregated messages].
        self.lin_node = MLP(in_dim=node_dim * 2, out_dim=node_dim,
                            hidden_layers=layers)

    def forward(self, x, edge_index, edge_attr):
        """Run message passing, then update node features with an MLP."""
        # Message passing: computes per-edge messages and aggregates
        # them at each node.
        message_info = self.propagate(edge_index, x=x, edge_attr=edge_attr)

        # Concatenate the aggregated messages with the input node features.
        x = torch.cat((x, message_info), dim=-1)

        # Node update.
        x = self.lin_node(x)

        return x, edge_attr

    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor):
        # Per-edge message: an MLP over the target node, the source node,
        # and the edge features.
        x = torch.cat((x_i, x_j, edge_attr), dim=-1)
        return self.lin_edge(x)
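
To make the sketch self-contained, here is a minimal, hypothetical `MLP` helper matching the signature used above, plus a smoke test (all names and dimensions are illustrative):

    import torch
    import torch.nn as nn

    class MLP(nn.Module):
        # Simple MLP: in_dim -> out_dim with `hidden_layers` hidden layers.
        def __init__(self, in_dim, out_dim, hidden_layers=3, hidden_dim=64):
            super().__init__()
            dims = [in_dim] + [hidden_dim] * hidden_layers
            layers = []
            for d_in, d_out in zip(dims[:-1], dims[1:]):
                layers += [nn.Linear(d_in, d_out), nn.ReLU()]
            layers.append(nn.Linear(dims[-1], out_dim))
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            return self.net(x)

    conv = MPGNNConv(node_dim=16, edge_dim=4)
    x = torch.randn(10, 16)                     # 10 nodes
    edge_index = torch.randint(0, 10, (2, 40))  # 40 random edges
    edge_attr = torch.randn(40, 4)
    out, _ = conv(x, edge_index, edge_attr)     # out: [10, 16]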
rusty1s commented 1 year ago

I think the current implementation is correct, see Section 5.1 (Message Functions) of the paper. Specifically, the edge features are used to compute a d × d matrix that is then used to transform the node features.
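
In the paper's notation, the edge network message function from Section 5.1 is

$$M(h_v, h_w, e_{vw}) = A(e_{vw})\, h_w,$$

where $A(e_{vw})$ is the $d \times d$ matrix produced by applying a neural network to the edge features; this is exactly the `weight = self.nn(edge_attr)` plus reshape-and-matmul in the `message` method quoted above.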

Forbu commented 1 year ago

You are totally right ... my bad

Forbu commented 1 year ago

I was confused by the different architecture types ... I thought it was equivalent to the graph block from http://proceedings.mlr.press/v80/sanchez-gonzalez18a/sanchez-gonzalez18a.pdf
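
For reference, the graph (GN) block from that paper roughly computes

$$e'_{ij} = \phi_e(e_{ij}, x_i, x_j), \qquad x'_i = \phi_v\Big(x_i, \sum_{j \in N(i)} e'_{ij}\Big),$$

where $\phi_e$ and $\phi_v$ are MLPs over concatenated inputs, which is what the MPGNNConv sketch above implements, whereas NNConv's message is the edge-conditioned linear map $A(e_{vw})\, h_w$.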