snap-stanford / ogb

Benchmark datasets, data loaders, and evaluators for graph machine learning
https://ogb.stanford.edu
MIT License

Virtual node implementation wrong? #291

Closed: zjost closed this issue 2 years ago

zjost commented 2 years ago

Hello team. I was reviewing some baselines for graph property prediction and had a question about the virtual node implementation, which appears in a number of places, like here.

My understanding (which may be wrong) is that we treat the virtual node as just another node in the graph, connected to all other nodes. In the update equation for node i, the virtual node embedding would then be treated as "just another message", but given slightly different treatment since it has no BondEncoder embedding.
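Reading it back from the `forward` sketch later in this comment (so this is a reconstruction, not a quoted formula), with $\mathcal{N}(i)$ the neighbors of node $i$, $e_{ji}$ the BondEncoder embedding of edge $j \to i$, and the $\mathrm{ReLU}(h_j + e_{ji})$ message used by OGB's edge-feature GIN, that update would be:

$$
h_i^{(l+1)} = \mathrm{MLP}^{(l)}\Big( (1+\epsilon^{(l)})\, h_i^{(l)} + h_{\mathrm{virtual}}^{(l)} + \sum_{j \in \mathcal{N}(i)} \mathrm{ReLU}\big(h_j^{(l)} + e_{ji}\big) \Big)
$$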

But since the public implementation adds the virtual node embedding to the other node embeddings before the convolution (code), it seems you are implementing a different update, in which the virtual node embedding also enters every neighbor's message.
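Writing $\mathcal{N}(i)$ for the neighbors of node $i$ and $e_{ji}$ for the bond embedding of edge $j \to i$, and assuming $h_{\mathrm{virtual}}^{(l)}$ is added to every node before an unmodified edge-feature GIN layer with message $\mathrm{ReLU}(h_j + e_{ji})$, the update presumably reads:

$$
h_i^{(l+1)} = \mathrm{MLP}^{(l)}\Big( (1+\epsilon^{(l)})\big(h_i^{(l)} + h_{\mathrm{virtual}}^{(l)}\big) + \sum_{j \in \mathcal{N}(i)} \mathrm{ReLU}\big(h_j^{(l)} + h_{\mathrm{virtual}}^{(l)} + e_{ji}\big) \Big)
$$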

Is this intended? Assuming not, I think a correct implementation could include changing the convolution layer itself and making the forward pass something like:

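def forward(self, x, edge_index, edge_attr, h_virt):
    # h_virt: virtual node embedding already gathered per node, e.g. vn_emb[batch]
    edge_embedding = self.bond_encoder(edge_attr)  # this layer's own BondEncoder
    # Neighbor sum of ReLU(x_j + e_ji), as in OGB's GINConv.message
    aggr = self.propagate(edge_index, x=x, edge_attr=edge_embedding)
    # h_virt is added once per node, outside the neighbor sum
    out = self.mlp((1 + self.eps) * x + h_virt + aggr)

    return out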
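To see the difference numerically, here is a toy, dependency-free sketch comparing the two updates on a 3-node path graph. It assumes scalar embeddings, an identity MLP, $\epsilon = 0$, and hypothetical helper names (`gin_layer`, `proposed`, `implemented`); the message $\mathrm{ReLU}(h_j + e_{ji})$ follows the edge-feature GIN discussed in this thread.

```python
# Toy comparison of two ways to inject a virtual-node embedding into one GIN layer.
# Assumptions (not OGB code): scalar embeddings, identity MLP, hypothetical names.

def relu(z):
    return max(z, 0.0)

def gin_layer(h, neighbors, edge_emb, eps=0.0, extra=None):
    """One GIN update with scalar features and an identity MLP.

    h: list of node embeddings; neighbors[i]: nodes j with an edge j -> i;
    edge_emb[(j, i)]: bond embedding of edge j -> i;
    extra: optional per-node term added once, outside the neighbor sum.
    """
    out = []
    for i in range(len(h)):
        aggr = sum(relu(h[j] + edge_emb[(j, i)]) for j in neighbors[i])
        z = (1 + eps) * h[i] + aggr
        if extra is not None:
            z += extra[i]
        out.append(z)  # identity MLP
    return out

def proposed(h, neighbors, edge_emb, v, eps=0.0):
    # Virtual node as "just another message": added once per node update.
    return gin_layer(h, neighbors, edge_emb, eps, extra=[v] * len(h))

def implemented(h, neighbors, edge_emb, v, eps=0.0):
    # Add v to every node first, then run an unmodified GIN layer,
    # so v also appears inside every neighbor message.
    return gin_layer([x + v for x in h], neighbors, edge_emb, eps)

if __name__ == "__main__":
    # Path graph 0 - 1 - 2, edges in both directions, all bond embeddings 0.5.
    neighbors = {0: [1], 1: [0, 2], 2: [1]}
    edge_emb = {(1, 0): 0.5, (0, 1): 0.5, (2, 1): 0.5, (1, 2): 0.5}
    h = [1.0, 2.0, 3.0]
    print(proposed(h, neighbors, edge_emb, v=1.0))     # [4.5, 8.0, 6.5]
    print(implemented(h, neighbors, edge_emb, v=1.0))  # [5.5, 10.0, 7.5]
```

Because every ReLU argument is positive here, the two outputs differ by exactly deg(i) times v, i.e. [1, 2, 1]: the extra copies of the virtual node embedding that enter through the neighbor messages.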

Cheers

zjost commented 2 years ago

I have a second question about the implementation. It seems a BondEncoder is instantiated inside the GIN/GCN convolution modules, but the AtomEncoder is instantiated in the GNN_node module (code). This means there will be K different BondEncoders but only one AtomEncoder, where K is the number of graph convolution layers. Is this intentional? Thanks again.

weihua916 commented 2 years ago

Hi!

Regarding the virtual node, we just add the virtual node embedding (which captures global information) to the node embeddings in each round of message passing (see https://github.com/snap-stanford/ogb/blob/master/examples/graphproppred/mol/conv.py#L199). You should think of the virtual node algorithm as $h_i^{(l)} \leftarrow h_i^{(l)} + h_{\mathrm{virtual}}^{(l)}$, which is then used as the input to the ordinary GIN convolution (with edge features).
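A simplified single-graph sketch of the scheme described above: before each layer every node receives $h_{\mathrm{virtual}}$, then an ordinary edge-feature GIN layer runs. The virtual-node update between layers (sum-pooling that layer's input node embeddings and adding the previous virtual state) is based on our reading of OGB's example code and should be treated as an assumption; MLPs are replaced by identities, batch norm, inter-layer ReLU, dropout and residual connections are omitted, and the names `gin_conv` and `virtual_node_forward` are hypothetical.

```python
# Simplified virtual-node forward pass (assumptions: scalar embeddings,
# identity MLPs, no batch norm / ReLU / dropout / residuals between layers).

def relu(z):
    return max(z, 0.0)

def gin_conv(h, neighbors, edge_emb, eps=0.0):
    # Ordinary edge-feature GIN layer with an identity MLP.
    return [
        (1 + eps) * h[i] + sum(relu(h[j] + edge_emb[(j, i)]) for j in neighbors[i])
        for i in range(len(h))
    ]

def virtual_node_forward(h, neighbors, edge_emb, num_layers):
    h_virtual = 0.0  # virtual node state starts at zero in this sketch
    for layer in range(num_layers):
        # h_i <- h_i + h_virtual, then the unmodified GIN convolution.
        h_in = [x + h_virtual for x in h]
        h = gin_conv(h_in, neighbors, edge_emb)
        if layer < num_layers - 1:
            # Virtual node update: sum-pool this layer's input plus old state.
            h_virtual = sum(h_in) + h_virtual
    return h, h_virtual
```

Because `h_virtual` is a single value shared by all nodes, it is how information from the whole graph reaches every node in the next layer.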

Regarding the encoders, it is intentional that a BondEncoder is instantiated in each layer of message passing while the AtomEncoder is instantiated only once, in the input layer. Node/atom embeddings are refined by each round of message passing (hence there is no need to instantiate an encoder in every layer), whereas edge/bond embeddings are never updated and are treated as input features, so each layer learns its own embedding of them.
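A dependency-free sketch of that layout. It uses hypothetical lookup-table stand-ins (`LookupEncoder`, `GNNNodeSketch`) rather than OGB's real `AtomEncoder`/`BondEncoder`, with scalar embeddings, $\epsilon = 0$ and identity MLPs: the atom encoder runs once on the raw atom features, while every layer re-embeds the same raw bond features with its own table.

```python
# Encoder layout sketch: one atom encoder, one bond encoder per layer.
# LookupEncoder and GNNNodeSketch are hypothetical stand-ins, not OGB classes.

class LookupEncoder:
    """Maps each integer feature to a scalar embedding via a fixed table."""

    def __init__(self, table):
        self.table = table

    def __call__(self, feats):
        return [self.table[f] for f in feats]

class GNNNodeSketch:
    def __init__(self, num_layers, atom_table, bond_tables):
        assert len(bond_tables) == num_layers
        # Atom features are embedded once, at the input; later layers refine h.
        self.atom_encoder = LookupEncoder(atom_table)
        # Bond features are never updated, so each layer embeds them itself.
        self.bond_encoders = [LookupEncoder(t) for t in bond_tables]

    def forward(self, atom_feats, edges, bond_feats):
        # edges: list of (j, i) pairs meaning j -> i; bond_feats aligned with edges.
        h = self.atom_encoder(atom_feats)
        for bond_encoder in self.bond_encoders:
            e = bond_encoder(bond_feats)
            new_h = list(h)  # self term with eps = 0, identity MLP
            for (j, i), e_ji in zip(edges, e):
                new_h[i] += max(h[j] + e_ji, 0.0)  # ReLU(h_j + e_ji)
            h = new_h
        return h
```

For example, `GNNNodeSketch(2, {0: 1.0, 1: 2.0}, [{0: 0.5}, {0: -0.5}]).forward([0, 1], [(0, 1), (1, 0)], [0, 0])` gives `[6.5, 6.5]`, with the two layers reading the same bond feature through different tables.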

zjost commented 2 years ago

Thanks for the response. I think the issue is that you're adding the virtual node embedding before the convolution rather than after the neighbor aggregation. As a result, the virtual node embedding is added N times for each node update, where N is the number of neighbors: it occurs under the sum over neighbors instead of outside of it. Ideally, you would pass the virtual node embedding to the GIN layer as an extra argument, and the layer would then add it here.
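To make the "under the sum" point concrete, take an edge-feature GIN layer with message $\mathrm{ReLU}(h_j + e_{ji})$, neighbors $\mathcal{N}(i)$, and $h_{\mathrm{virtual}}^{(l)}$ added to every node beforehand. Assuming, for illustration only, that every ReLU argument is positive (so each ReLU acts as the identity), the input to the MLP expands to:

$$
\begin{aligned}
&(1+\epsilon^{(l)})\big(h_i^{(l)} + h_{\mathrm{virtual}}^{(l)}\big) + \sum_{j \in \mathcal{N}(i)} \mathrm{ReLU}\big(h_j^{(l)} + h_{\mathrm{virtual}}^{(l)} + e_{ji}\big) \\
&\qquad = (1+\epsilon^{(l)})\, h_i^{(l)} + \sum_{j \in \mathcal{N}(i)} \big(h_j^{(l)} + e_{ji}\big) + \big(1 + \epsilon^{(l)} + |\mathcal{N}(i)|\big)\, h_{\mathrm{virtual}}^{(l)}
\end{aligned}
$$

whereas adding $h_{\mathrm{virtual}}^{(l)}$ once outside the neighbor sum gives it a coefficient of 1. When some ReLU arguments are negative the clipping breaks this clean count, but the virtual node embedding still enters through every neighbor term rather than once.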

weihua916 commented 2 years ago

I do not think that's an issue. That's how we implemented the virtual node algorithm. I do not think there was an established definition of the virtual node algorithm prior to our work; we just implemented one variant of it.

zjost commented 2 years ago

Thank you for the clarification!