pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

How to update edge embeddings? #1908

Closed thu-wangz17 closed 3 years ago

thu-wangz17 commented 3 years ago

❓ Questions & Help

Hi, I have a question about how to update node embeddings and edge embeddings simultaneously. I want to update the edge embeddings for the next update of the node embeddings, where the edge embedding $e_{ij}^{l+1}$ is computed from $x_i^l$, $x_j^l$, and $e_{ij}^l$. I think this should be done in the message method, but message can only return a single Tensor. Are there any examples? Thank you.

rusty1s commented 3 years ago

Yeah, computing this in message makes the most sense, but message currently has to return a single Tensor. You have two options to tackle this:

1.

class MyConv(MessagePassing):
    ...
    def forward(self, x, edge_index, edge_attr):
        row, col = edge_index

        # Update edge embeddings in `forward`:
        edge_attr = self.mlp(torch.cat([x[row], edge_attr, x[col]], dim=-1))

        return self.propagate(edge_index, x=x, edge_attr=edge_attr), edge_attr

2.

class MyConv(MessagePassing):
    ...
    def forward(self, x, edge_index, edge_attr):
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        edge_attr = self.__edge_attr__
        self.__edge_attr__ = None
        return out, edge_attr

    def message(self, x_i, x_j, edge_attr):
        # Cache the updated edge embeddings so that `forward` can return them:
        self.__edge_attr__ = self.mlp(torch.cat([x_i, edge_attr, x_j], dim=-1))
        return self.__edge_attr__

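For completeness, a tiny end-to-end sketch of option 2 (the MLP, feature sizes, and toy graph below are placeholder assumptions, not part of the original snippet):

import torch
from torch.nn import Linear, ReLU, Sequential

conv = MyConv()  # option 2 from above, using the default aggregation
# The `...` above hides the layer definition; a placeholder MLP:
conv.mlp = Sequential(Linear(3 * 16, 16), ReLU(), Linear(16, 16))

x = torch.randn(4, 16)                     # 4 nodes with 16 features each
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])  # a directed 4-cycle
edge_attr = torch.randn(4, 16)             # one 16-dimensional feature per edge

out, new_edge_attr = conv(x, edge_index, edge_attr)
# `out` holds the aggregated messages per node, `new_edge_attr` the updated
# edge embeddings computed from [x_i, e_ij, x_j].
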
thu-wangz17 commented 3 years ago

That's very nice! Thank you very much.

scarpma commented 3 years ago

Hi again @rusty1s, I want to implement the message passing interface used in this work, defined in this article (Learning Mesh-Based Simulation with Graph Networks).

I'm trying to create the "processor" network, and my question is whether I can implement this update scheme:

$e_{ij}' = f(x_i, e_{ij}, x_j), \qquad x_i' = g\big(x_i, \sum_{j \in \mathcal{N}(i)} e_{ij}'\big)$

where $f$ and $g$ are two multilayer perceptrons (later $f$ = self.edge_mlp and $g$ = self.node_mlp).

My solution, based on the answer you gave in this issue, is:

class processor(MessagePassing):

    def __init__(self):
        super(processor, self).__init__(aggr='add')

    ...

    def forward(self, x, edge_index, edge_attr):
        row, col = edge_index

        # Update edge embeddings in `forward`:
        edge_attr = self.edge_mlp(torch.cat([x[row], edge_attr, x[col]], dim=-1))

        return self.propagate(edge_index, x=x, edge_attr=edge_attr), edge_attr

    def message(self, edge_attr):
        return edge_attr

    def update(self, inputs, edge_index, x):
        row, col = edge_index
        return self.node_mlp(torch.cat([x, inputs], dim=-1))

Do you think it is correct? And what do you think about performance? Maybe it could be implemented as a sparse matrix multiplication?

rusty1s commented 3 years ago

This is correct, although I would suggest applying edge_mlp in message:

class processor(MessagePassing):

    def __init__(self):
        super(processor, self).__init__(aggr='add')

    ...

    def forward(self, x, edge_index, edge_attr):
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        # Return the updated edge embeddings cached in `message` (see option 2 above):
        edge_attr, self.__edge_attr__ = self.__edge_attr__, None
        return out, edge_attr

    def message(self, x_i, x_j, edge_attr):
        self.__edge_attr__ = self.edge_mlp(torch.cat([x_i, edge_attr, x_j], dim=-1))
        return self.__edge_attr__

    def update(self, inputs, x):
        return self.node_mlp(torch.cat([x, inputs], dim=-1))

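A quick usage sketch (the MLP definitions, hidden size, and toy graph below are illustrative assumptions, not part of the answer; MeshGraphNets additionally wraps both updates with LayerNorm and residual connections):

import torch
from torch.nn import Linear, ReLU, Sequential

hidden = 32
proc = processor()
# The `...` above hides the two MLPs; placeholder choices:
proc.edge_mlp = Sequential(Linear(3 * hidden, hidden), ReLU(), Linear(hidden, hidden))
proc.node_mlp = Sequential(Linear(2 * hidden, hidden), ReLU(), Linear(hidden, hidden))

x = torch.randn(6, hidden)                       # node features
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])  # a directed 6-cycle
edge_attr = torch.randn(6, hidden)               # edge features

# Several message-passing rounds, feeding the updated edge embeddings back in:
for _ in range(3):
    x, edge_attr = proc(x, edge_index, edge_attr)
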
scarpma commented 3 years ago

Thank you very much, Matthias!

Do you think it could be implemented as a sparse matrix multiplication?

rusty1s commented 3 years ago

That's not possible, since you can only integrate one-dimensional edge features into a sparse matrix multiplication (where they act as a weighting of the neighbors). In this case, it needs to be implemented the way shown above.
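To illustrate the distinction (toy sizes, not from the thread): scalar, one-dimensional edge weights can indeed be folded into a sparse adjacency matrix, whereas the per-edge MLP output above cannot:

import torch

num_nodes = 5
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0]])
edge_weight = torch.rand(edge_index.size(1))  # one scalar per edge
x = torch.randn(num_nodes, 16)

# Scalar edge weights define a weighted adjacency matrix, so aggregation
# becomes a single sparse matrix multiplication:
adj = torch.sparse_coo_tensor(edge_index, edge_weight, (num_nodes, num_nodes))
out = torch.sparse.mm(adj, x)  # each row is a weighted sum of neighbor features

# With multi-dimensional edge features, every message edge_mlp([x_i, e_ij, x_j])
# is a different vector per edge, so no single weighted adjacency matrix can
# reproduce it; hence the explicit `message`-based implementation above.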

scarpma commented 3 years ago

Thank you very much!