rampasek / GraphGPS

Recipe for a General, Powerful, Scalable Graph Transformer
MIT License

BUG: adding graphormer graph token breaks node permutation invariance #40

Open isefos opened 9 months ago

isefos commented 9 months ago

I was experimenting with the graphormer model, specifically for graph classification using the virtual node for global pooling (graph_pooling: graph_token).

Problem

I noticed that the model produced different outputs for the same input graph when its nodes were permuted. The problem should be easy to reproduce; here is an example:

import torch
from torch_geometric.data import Batch

# given some data batch and a model, e.g. inside the training loop
# create a copy of the first graph
data = Batch.from_data_list([batch.get_example(0).clone()])
data_p = Batch.from_data_list([batch.get_example(0).clone()])

# and permute the nodes: 
# here we simply put the previously last node in first place of the first graph
n = data_p.x.size(0)
p = torch.arange(n, dtype=torch.long) - 1
p[0] = n - 1
data_p.x = data_p.x[p]
assert (data_p.x[0, :] == data.x[-1, :]).all()
assert (data_p.x[1:, :] == data.x[:-1, :]).all()

# make sure to permute the other node features as well
data_p.batch = data_p.batch[p]
data_p.in_degrees = data_p.in_degrees[p]
data_p.out_degrees = data_p.out_degrees[p]

# and change the indices accordingly (all increase by one, just the last one gets set to zero)
n = data_p.x.size(0)
data_p.edge_index += 1
data_p.edge_index[data_p.edge_index == n] = 0
data_p.graph_index += 1
data_p.graph_index[data_p.graph_index == n] = 0

# then get the model outputs for each graph
model.eval()
with torch.no_grad():
    output, _ = model(data)
    output_p, _ = model(data_p)

# check if outputs are equal
assert torch.allclose(output, output_p), "Permuted graph produces different output!"

This is unexpected (and worrisome) behavior. In theory, the model should be invariant to such node permutations, as should any GNN with a permutation-invariant readout.
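The example above uses a cyclic shift, which only needs the "+1, wrap n to 0" trick. For an arbitrary permutation, a minimal sketch is to gather per-node tensors (x, in_degrees, out_degrees, batch) with p and to remap index tensors (edge_index, graph_index) through the inverse permutation. The helper name permute_graph is hypothetical and not part of GraphGPS; it only handles x and edge_index:

```python
import torch


def permute_graph(x, edge_index, perm):
    """Hypothetical helper: relabel nodes so that new node i is old node perm[i].

    Per-node tensors are gathered with perm; index tensors are remapped
    through the inverse permutation.
    """
    n = perm.numel()
    # inv[old_label] = new_label
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(n, dtype=perm.dtype)
    return x[perm], inv[edge_index]


# Same permutation as in the example: the last node moves to the front.
n = 5
x = torch.randn(n, 3)
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])
p = torch.arange(n, dtype=torch.long) - 1
p[0] = n - 1

x_p, edge_index_p = permute_graph(x, edge_index, p)

# The general remapping matches the "+1, wrap n to 0" trick ...
assert torch.equal(edge_index_p, (edge_index + 1) % n)
# ... and every edge still connects the same node features as before.
assert torch.equal(x_p[edge_index_p], x[edge_index])
```

Using the inverse permutation avoids hand-deriving the remapping rule for each new permutation.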

Cause

The cause turned out to be in the add_graph_token function, in these lines:

data.batch, sort_idx = torch.sort(data.batch)
data.x = data.x[sort_idx]

torch.sort is called so that each newly concatenated virtual node ends up grouped together with the other nodes of its graph.

But it is called without the stable argument, so the default stable=False is used, and torch.sort does not guarantee that equal elements keep their relative order. As a result, the nodes inside each graph (same batch index) do not necessarily stay in their original order; instead, the sorting algorithm may permute each graph's nodes. By itself this would not be a problem, since the model should be invariant to such permutations. However, the other data attributes (edge_index, in_degrees, att_bias, etc.) still refer to the old node order and would then also have to be permuted or remapped, which does not happen.
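The stability requirement can be checked directly. The sketch below tests whether a sort permutation keeps each graph's nodes in their original relative order. The helper within_graph_order_preserved and the toy batch vector are illustrative assumptions; in particular, it assumes the virtual nodes are appended after all real nodes before sorting:

```python
import torch


def within_graph_order_preserved(batch_sorted, sort_idx):
    """Return True if, inside every graph (run of equal batch values),
    the original node indices still appear in increasing order."""
    same_graph = batch_sorted[1:] == batch_sorted[:-1]
    increasing = sort_idx[1:] > sort_idx[:-1]
    # Only consecutive pairs from the same graph have to be increasing.
    return bool((increasing | ~same_graph).all())


# Two graphs with 3 and 2 nodes, followed by one virtual node per graph.
batch = torch.tensor([0, 0, 0, 1, 1, 0, 1])

batch_sorted, sort_idx = torch.sort(batch, stable=True)
assert within_graph_order_preserved(batch_sorted, sort_idx)
print(sort_idx.tolist())  # [0, 1, 2, 5, 3, 4, 6]

# A correct but non-stable grouping of the same batch vector: the nodes
# within each graph are reordered, so x no longer lines up with edge_index etc.
shuffled_idx = torch.tensor([2, 0, 5, 1, 4, 6, 3])
assert not within_graph_order_preserved(batch[shuffled_idx], shuffled_idx)
```

Both index vectors sort batch correctly; only the stable one is safe when other tensors still refer to the original node order.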

Fix

The much simpler solution, of course, is to use a stable sort by changing the line to:

data.batch, sort_idx = torch.sort(data.batch, stable=True)

With this change, rerunning the example above indeed gives identical outputs!
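For illustration, the token-append step with the fixed sort might look roughly like the sketch below. The function append_graph_tokens, its arguments, and the choice to concatenate the tokens after the existing nodes are assumptions, not the actual GraphGPS add_graph_token code:

```python
import torch


def append_graph_tokens(x, batch, token):
    """Hypothetical sketch: add one virtual node per graph and regroup by graph.

    Assumes `token` is a (hidden_dim,) embedding shared by all graphs and that
    the virtual nodes are concatenated after the existing nodes.
    """
    num_graphs = int(batch.max()) + 1
    x = torch.cat([x, token.expand(num_graphs, -1)], dim=0)
    batch = torch.cat([batch, torch.arange(num_graphs, dtype=batch.dtype)])
    # stable=True keeps the relative order of each graph's original nodes.
    batch, sort_idx = torch.sort(batch, stable=True)
    return x[sort_idx], batch


x = torch.arange(5, dtype=torch.float).unsqueeze(1)  # node i has feature i
batch = torch.tensor([0, 0, 0, 1, 1])
token = torch.tensor([-1.0])

x_out, batch_out = append_graph_tokens(x, batch, token)
print(x_out.squeeze(1).tolist())  # [0.0, 1.0, 2.0, -1.0, 3.0, 4.0, -1.0]
print(batch_out.tolist())         # [0, 0, 0, 0, 1, 1, 1]
```

With stable sorting, each graph's original nodes appear in their input order, followed by that graph's token.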

I haven't yet tested how this fix affects training and classification performance, but I imagine that being node-permutation invariant, and not having the node features "randomly" permuted, would make things a bit easier for the model...

migalkin commented 9 months ago

Did you also permute positional encodings (on both node and edge level) that are needed for Graphormer?

isefos commented 9 months ago

Yes, I believe so, at least when using the "normal" Graphormer preprocessing and encoder. The node-level positional encodings are the in- and out-degrees, which I permute in the example above with:

data_p.in_degrees = data_p.in_degrees[p]
data_p.out_degrees = data_p.out_degrees[p]

At the edge level, the encodings are stored "sparsely" (spatial_types and shortest_path_types are indexed by graph_index, analogously to edge_attr with edge_index). So to reflect the node permutation on the edges, I made sure the index arrays are remapped to the new node labels. In my (very specific) example, this is done by:

n = data_p.x.size(0)
data_p.edge_index += 1
data_p.edge_index[data_p.edge_index == n] = 0
data_p.graph_index += 1
data_p.graph_index[data_p.graph_index == n] = 0

The example data I am using has no edge_attr, and therefore no shortest_path_types either, but the example above should also work correctly for data with edge attributes, because edge_index and graph_index are remapped.
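Why remapping only the index arrays is enough can be checked on a toy graph: scattering the sparse encodings into a dense node-by-node matrix before and after the remap should give the same matrix with rows and columns permuted by p. The helper to_dense and the assumption that graph_index lists every ordered node pair are illustrative, not the GraphGPS data layout:

```python
import torch


def to_dense(index, values, n):
    """Scatter sparse pair encodings (index: 2 x E, values: E) into an n x n matrix."""
    dense = torch.zeros(n, n, dtype=values.dtype)
    dense[index[0], index[1]] = values
    return dense


n = 4
# Toy assumption: graph_index covers all ordered node pairs.
graph_index = torch.cartesian_prod(torch.arange(n), torch.arange(n)).t()
spatial_types = torch.randint(0, 5, (graph_index.size(1),))

# Same cyclic shift as in the example: remap the index, keep the values as-is.
p = torch.arange(n, dtype=torch.long) - 1
p[0] = n - 1
graph_index_p = graph_index + 1
graph_index_p[graph_index_p == n] = 0

dense = to_dense(graph_index, spatial_types, n)
dense_p = to_dense(graph_index_p, spatial_types, n)

# Remapping only the index is equivalent to permuting rows and columns.
assert torch.equal(dense_p, dense[p][:, p])
```

The value arrays (spatial_types here, edge_attr for edge_index) stay aligned row by row with their index columns, so they never need reordering themselves.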

Also, the fact that the model outputs are the same for the original and permuted graph when using stable sorting indicates that the positional encodings were probably permuted correctly as well...

luis-mueller commented 9 months ago

@isefos Thanks a lot for raising this issue. Indeed, this seems to be a bug, and I believe stable sorting is exactly what we need here. I will test this on my side and also run a few experiments to see whether it has any (positive or negative) impact on performance.