isefos opened this issue 1 year ago
Did you also permute the positional encodings (at both the node and the edge level) that Graphormer needs?
Yes, I believe so, at least when using the "normal" Graphormer preprocessing and encoder. The node-level positional encodings are the in- and out-degrees, which I permute in my example above with:
```python
data_p.in_degrees = data_p.in_degrees[p]
data_p.out_degrees = data_p.out_degrees[p]
```
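As a sanity check of the node-level part, here is a minimal sketch. It assumes the degrees are computed directly from a `[2, E]` `edge_index` with `torch.bincount`, which may not match the actual Graphormer preprocessing. The helper names `degrees` and `inverse_permutation` are hypothetical. The check shows that indexing the precomputed degrees with `p` gives the same result as recomputing them on the relabeled graph.

```python
import torch

def degrees(edge_index: torch.Tensor, n: int):
    """Hypothetical helper: in/out degrees from a [2, E] edge_index (row 0 = source)."""
    out_deg = torch.bincount(edge_index[0], minlength=n)
    in_deg = torch.bincount(edge_index[1], minlength=n)
    return in_deg, out_deg

def inverse_permutation(p: torch.Tensor) -> torch.Tensor:
    # With x_p = x[p], new node j is old node p[j]; inv maps old label -> new label
    inv = torch.empty_like(p)
    inv[p] = torch.arange(p.numel())
    return inv

# Small directed toy graph with 4 nodes
n = 4
edge_index = torch.tensor([[0, 0, 1, 2, 3],
                           [1, 2, 2, 3, 0]])
in_deg, out_deg = degrees(edge_index, n)

p = torch.tensor([2, 0, 3, 1])                # arbitrary permutation
in_deg_p, out_deg_p = in_deg[p], out_deg[p]   # permute the precomputed degrees

# Recompute the degrees on the relabeled graph and compare
edge_index_p = inverse_permutation(p)[edge_index]
in_check, out_check = degrees(edge_index_p, n)
print(torch.equal(in_deg_p, in_check), torch.equal(out_deg_p, out_check))  # True True
```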
For the edge level, the encodings are stored "sparsely": `spatial_types` and `shortest_path_types` are indexed by `graph_index`, analogously to `edge_attr` with `edge_index`. So, to reflect the node permutation in the edges, I made sure that the index arrays are remapped to the new node labels. In my (very specific) example this is done by:
```python
n = data_p.x.size(0)
# Relabel every node i -> (i + 1) mod n, in both index arrays
data_p.edge_index += 1
data_p.edge_index[data_p.edge_index == n] = 0
data_p.graph_index += 1
data_p.graph_index[data_p.graph_index == n] = 0
```
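The wrap-around "+1" trick only fits one particular permutation. Judging from the remapping code (the example itself is not shown here), that permutation is `p = torch.roll(torch.arange(n), 1)`. A minimal sketch of the general version for an arbitrary `p` follows. The helper `remap_indices` is hypothetical, and the toy `edge_index` is made up. For the rolled permutation it gives the same result as the trick above, and the same call would apply to `graph_index`.

```python
import torch

def remap_indices(index: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: relabel node indices after x_p = x[p] for any permutation p."""
    inv = torch.empty_like(p)
    inv[p] = torch.arange(p.numel())  # inv[old_label] = new_label
    return inv[index]

n = 5
# Permutation implied by "+1 with wrap-around": new node j is old node (j - 1) mod n
p = torch.roll(torch.arange(n), shifts=1)  # tensor([4, 0, 1, 2, 3])

edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0]])

# The specific trick from the snippet above
shifted = edge_index.clone()
shifted += 1
shifted[shifted == n] = 0

# The general inverse-permutation remap agrees with it
print(torch.equal(shifted, remap_indices(edge_index, p)))  # True
```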
The example data I am using has no `edge_attr` and therefore also no `shortest_path_types`, but the code above should still work correctly for data with edge attributes, because of the changes to `edge_index` and `graph_index`.
Also, the fact that the model outputs are the same for the original and permuted graph when using stable sorting indicates that the positional encodings were probably permuted correctly as well...
@isefos Thanks a lot for raising this issue. Indeed, this seems to be a bug. I also believe that stable sorting is exactly what we need here. I will test this on my side and also run a few experiments to see whether there is any (positive or negative) impact on performance.
I was experimenting with the Graphormer model, specifically for graph classification using the virtual node for global pooling (`graph_pooling: graph_token`).

Problem
I noticed that the model produces different outputs for the same input graph when its node order is permuted. The problem should be easy to reproduce; here is an example:
This is unexpected (and worrisome) behavior. In theory, the model architecture should be invariant to such changes, as should any GNN.
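To make "invariant" concrete, here is a minimal check one could run against any model. It relabels the nodes with a random permutation, remaps `edge_index` accordingly, and compares the graph-level outputs. `ToySumGNN` is a hypothetical toy model with sum aggregation and sum readout, not the Graphormer implementation. For Graphormer, the positional encodings (degrees, `spatial_types`, etc.) would also have to be permuted.

```python
import torch

torch.manual_seed(0)

class ToySumGNN(torch.nn.Module):
    """Hypothetical toy model (not Graphormer): one sum-aggregation layer + sum readout."""
    def __init__(self, dim: int):
        super().__init__()
        self.lin = torch.nn.Linear(2 * dim, dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])  # sum incoming messages per node
        h = torch.relu(self.lin(torch.cat([x, agg], dim=-1)))
        return h.sum(dim=0)  # permutation-invariant graph-level readout

def permute_graph(x, edge_index, p):
    # New node j is old node p[j]; edges are relabeled with the inverse permutation
    inv = torch.empty_like(p)
    inv[p] = torch.arange(p.numel())
    return x[p], inv[edge_index]

n, dim = 6, 8
x = torch.randn(n, dim)
edge_index = torch.randint(0, n, (2, 12))
p = torch.randperm(n)

model = ToySumGNN(dim).eval()
with torch.no_grad():
    out = model(x, edge_index)
    out_p = model(*permute_graph(x, edge_index, p))

# Summation order changes, so compare with a small tolerance instead of exact equality
print(torch.allclose(out, out_p, atol=1e-5))
```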
Cause
The cause turned out to be in the `add_graph_token` function, in this line:

`torch.sort` is called to group all the newly concatenated virtual nodes together with the other nodes of their respective graphs in the batch. But it is called without the `stable` argument, which means the default `stable=False` is used. As a result, the nodes inside each graph (same batch index) don't stay in their previous order; instead, each graph gets its nodes permuted by the sorting algorithm. This by itself would not necessarily be a problem, since the model should be invariant to such permutations. However, all the indices used in the other data attributes (`edge_index`, `in_degrees`, `att_bias`, etc.) still reference the old node order and would then also have to be permuted/remapped.

Fix
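For completeness, the remapping route would look roughly like the following minimal sketch. It assumes one virtual node per graph is prepended before sorting by graph id, which is a hypothetical layout; the actual tensors in `add_graph_token` may be arranged differently. Every node-indexed attribute would need the same treatment, not just `edge_index`.

```python
import torch

# Hypothetical batch: graph 0 has 3 nodes, graph 1 has 2 nodes
batch = torch.tensor([0, 0, 0, 1, 1])
x = torch.arange(5.0).unsqueeze(-1)             # features = original index, for tracking
edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])
num_graphs = 2

# Assumption: virtual nodes are prepended, then everything is sorted by graph id
batch_all = torch.cat([torch.arange(num_graphs), batch])
x_all = torch.cat([torch.full((num_graphs, 1), -1.0), x])

_, perm = torch.sort(batch_all)                 # default stable=False: within-graph order not guaranteed
batch_sorted, x_sorted = batch_all[perm], x_all[perm]

# Original node k sits at position num_graphs + k before sorting
inv = torch.empty_like(perm)
inv[perm] = torch.arange(perm.numel())          # inv[position before sort] = position after sort
edge_index_new = inv[edge_index + num_graphs]

# The remapped edges point at the same node features as before, whatever order the sort chose
print(torch.equal(x_sorted[edge_index_new], x[edge_index]))  # True
```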
Of course, the much simpler solution is to use stable sorting, i.e. to change the line to:
When I run the example above again with this change, the outputs are indeed the same!
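A small sketch of why `stable=True` resolves the issue follows. It rests on a hypothetical layout assumption: one virtual node per graph is prepended before sorting, and `batch` is sorted, as in PyG-style batches. Stable sorting keeps the relative order within each graph. So each graph's virtual node comes first, and original node `k` of graph `g` lands at the fixed position `k + g + 1`.

```python
import torch

# Hypothetical batch: graph 0 has 3 nodes, graph 1 has 2 nodes; virtual nodes prepended
batch = torch.tensor([0, 0, 0, 1, 1])
num_graphs = 2
batch_all = torch.cat([torch.arange(num_graphs), batch])  # tensor([0, 1, 0, 0, 0, 1, 1])

_, perm = torch.sort(batch_all, stable=True)
print(perm.tolist())  # [0, 2, 3, 4, 1, 5, 6]

# Position after sorting of each pre-sort position
inv = torch.empty_like(perm)
inv[perm] = torch.arange(perm.numel())

# Original node k (pre-sort position num_graphs + k) ends up at k + g + 1, g = its graph id
k = torch.arange(batch.numel())
print(torch.equal(inv[k + num_graphs], k + batch + 1))  # True
```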
I haven't yet tested how this fix affects training and classification performance, but I imagine that being invariant to node permutations, and not having the node features "randomly" permuted, would make things a bit easier for the model...