Forbu closed this issue 1 year ago.
I can push a fix based on this code I wrote:
import torch
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing

# MLP is the author's own helper module (not shown here); it is not
# torch_geometric.nn.MLP, whose constructor arguments differ.


class MPGNNConv(MessagePassing):
    def __init__(self, node_dim, edge_dim, layers=3):
        # mean aggregation; node_dim=0 tells MessagePassing that nodes lie along dim 0
        super().__init__(aggr='mean', node_dim=0)
        # edge MLP: [x_i, x_j, edge_attr] -> message
        self.lin_edge = MLP(in_dim=node_dim * 2 + edge_dim, out_dim=node_dim, hidden_layers=layers)
        # node MLP: [x, aggregated message] -> new node features
        self.lin_node = MLP(in_dim=node_dim * 2, out_dim=node_dim, hidden_layers=layers)

    def forward(self, x, edge_index, edge_attr):
        """
        Apply the message passing function, then apply the node MLP
        to the node features concatenated with the aggregated messages.
        """
        # message passing (calls message() per edge, then mean-aggregates)
        message_info = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        # concatenate the aggregated messages with the input node features
        x = torch.cat((x, message_info), dim=-1)
        # node update
        x = self.lin_node(x)
        return x, edge_attr

    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor):
        # message depends on both endpoint nodes and the edge features
        x = torch.cat((x_i, x_j, edge_attr), dim=-1)
        x = self.lin_edge(x)
        return x
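For comparison, the same edge-then-node update can be sketched in plain PyTorch without MessagePassing. This is a hypothetical, simplified sketch: `GNStyleLayer` is an illustrative name, single `nn.Linear` layers stand in for the author's `MLP` helper, and the mean aggregation is done by hand with `index_add_`. It shows a message computed from both endpoint nodes and the edge features, which is the graph-network-block style of update rather than an edge-conditioned weight matrix.

```python
import torch
from torch import Tensor, nn


class GNStyleLayer(nn.Module):
    """Hypothetical sketch: edge update from [x_i, x_j, e_ij], mean aggregation,
    node update from [x_i, m_i]. Single Linear layers stand in for MLPs."""

    def __init__(self, node_dim: int, edge_dim: int):
        super().__init__()
        self.edge_fn = nn.Linear(2 * node_dim + edge_dim, node_dim)
        self.node_fn = nn.Linear(2 * node_dim, node_dim)

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> Tensor:
        # PyG convention: row 0 holds sources (j), row 1 holds targets (i)
        src, dst = edge_index
        # per-edge message from receiver, sender and edge features
        msg = self.edge_fn(torch.cat((x[dst], x[src], edge_attr), dim=-1))
        # mean over incoming edges; nodes without incoming edges get zeros
        agg = torch.zeros(x.size(0), msg.size(-1), dtype=msg.dtype, device=x.device)
        agg.index_add_(0, dst, msg)
        count = torch.zeros(x.size(0), dtype=msg.dtype, device=x.device)
        count.index_add_(0, dst, torch.ones_like(dst, dtype=msg.dtype))
        agg = agg / count.clamp(min=1).unsqueeze(-1)
        # node update from the node's own features and its aggregated message
        return self.node_fn(torch.cat((x, agg), dim=-1))


# usage on a 4-node ring graph with 3-dimensional edge features
torch.manual_seed(0)
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_attr = torch.randn(4, 3)
layer = GNStyleLayer(8, 3)
out = layer(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([4, 8])
```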
I think the current implementation is correct; see Section 5.1 (Message Functions). Specifically, edge features are used to compute a d × d matrix that is used to transform node features.
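In symbols, the Edge Network message from Section 5.1 and the NNConv formula from the PyG documentation line up as follows. The edge feature size d_e is my own notation, and PyG writes node features as row vectors, hence the product x_j · h_Θ(e_ij):

```latex
\begin{aligned}
M(h_v, h_w, e_{vw}) &= A(e_{vw})\, h_w, \qquad A : \mathbb{R}^{d_e} \to \mathbb{R}^{d \times d} \\
\mathbf{x}'_i &= \mathbf{\Theta}\,\mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j})
\end{aligned}
```

Under this reading, the sum in the second line is the paper's aggregated message with h_Θ playing the role of A (with in_channels = out_channels = d it produces exactly a d × d matrix per edge), and Θx_i is an additional root term.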
You are totally right... my bad.
I mixed up the two architecture types... I thought it was equivalent to the graph block from http://proceedings.mlr.press/v80/sanchez-gonzalez18a/sanchez-gonzalez18a.pdf
🐛 Describe the bug
This issue is not really about a bug but a question about a custom graph neural network operator: the NNConv operator (https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.NNConv).
There may be an issue with this layer (though I may have misunderstood something). In the paper Neural Message Passing for Quantum Chemistry (https://arxiv.org/abs/1704.01212), whose architecture NNConv reproduces, there are two operations:
Basically, the message is a function of the edge attributes and the two nodes the edge connects, and the aggregated message is then processed together with the node's own information (the update).
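Written out, these are Eqs. (1) and (2) of the paper: a message phase that sums the message function M_t over the neighbours w of node v, and an update phase in which U_t combines the node state with that sum:

```latex
\begin{aligned}
m_v^{t+1} &= \sum_{w \in N(v)} M_t\left(h_v^t, h_w^t, e_{vw}\right) \\
h_v^{t+1} &= U_t\left(h_v^t, m_v^{t+1}\right)
\end{aligned}
```

This is a template: a particular choice of M_t does not have to use all three of its arguments.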
But when I look at the source code of the NNConv layer, its message method does not seem to implement those operations (https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/nn_conv.html#NNConv).
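To check the correspondence, here is a plain-PyTorch sketch of the computation described by the NNConv documentation formula. It is not PyG's source: the function name `nnconv_like`, the small edge network, and the use of sum aggregation (NNConv's default `aggr='add'`) are assumptions for illustration. The per-edge matrix computed from `edge_attr` plays the role of the paper's A(e_vw); the update U_t is left to the surrounding model (for example a GRU applied after the convolution).

```python
import torch
from torch import Tensor, nn


def nnconv_like(x: Tensor, edge_index: Tensor, edge_attr: Tensor,
                edge_nn: nn.Module, root: nn.Linear, out_channels: int) -> Tensor:
    """Hypothetical sketch of x'_i = Theta x_i + sum_j x_j . h(e_ij)."""
    src, dst = edge_index
    in_channels = x.size(-1)
    # edge network: one (in_channels x out_channels) matrix per edge, i.e. A(e_vw)
    weight = edge_nn(edge_attr).view(-1, in_channels, out_channels)
    # per-edge message: source features (row vector) times their edge matrix
    msg = torch.bmm(x[src].unsqueeze(1), weight).squeeze(1)
    # sum messages into their target nodes
    out = torch.zeros(x.size(0), out_channels, dtype=x.dtype, device=x.device)
    out.index_add_(0, dst, msg)
    # root term: the node's own features through a separate linear map
    return out + root(x)


# usage: 3 nodes, 4 directed edges, 3-dimensional edge features
torch.manual_seed(0)
in_ch, out_ch, edge_dim = 4, 5, 3
x = torch.randn(3, in_ch)
edge_index = torch.tensor([[0, 1, 2, 0], [1, 2, 0, 2]])
edge_attr = torch.randn(4, edge_dim)
edge_nn = nn.Sequential(nn.Linear(edge_dim, 16), nn.ReLU(), nn.Linear(16, in_ch * out_ch))
root = nn.Linear(in_ch, out_ch)
out = nnconv_like(x, edge_index, edge_attr, edge_nn, root, out_ch)
```

Node 2 receives edges from nodes 1 and 0, so its output is root(x[2]) + x[1]·A[1] + x[0]·A[3]: the message depends on the edge features and the sending node, not on the receiving node.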
Environment
Installation (conda, pip, source): pip