In the original implementation of the GPSLayer (found in graphgps/layer/gps_layer.py), edge structural encodings are used in the multi-head attention block. Specifically, they are stored in batch.attn_bias as a real-valued adjacency matrix. The PyTorch Geometric (PyG) implementation of GPS does not currently support edge structural encodings. Adding this option should be straightforward and require only two modifications:
1. Check whether the parameter attn_bias has been provided by the user in the forward call.
2. If attn_bias is provided and multi-head attention has been selected as the transformer layer, pass attn_mask=attn_bias when calling self.attn.
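The two modifications above could look roughly like the sketch below. It is written against plain torch.nn.MultiheadAttention (created with batch_first=True) rather than the actual PyG layer, and the function name biased_self_attention and its arguments are hypothetical. Two details are assumptions on my side: a 3D attn_mask has to be expanded to shape (B * num_heads, N, N), and the padding mask is folded into the float bias so that only one mask is passed.

```python
import torch
from torch import nn


def biased_self_attention(attn, h, mask, attn_bias=None):
    """Self-attention over a dense batch with an optional additive bias.

    attn:      nn.MultiheadAttention built with batch_first=True
    h:         [B, N, C] dense node features (e.g. from to_dense_batch)
    mask:      [B, N] bool, True for real nodes, False for padding
    attn_bias: optional [B, N, N] real-valued bias (hypothetical argument)
    """
    if attn_bias is None:
        # Unbiased path: only hide the padded nodes.
        out, _ = attn(h, h, h, key_padding_mask=~mask, need_weights=False)
        return out

    # Fold padding into the bias: padded key columns get -inf.
    bias = attn_bias.to(h.dtype).masked_fill(~mask[:, None, :], float("-inf"))
    # A float attn_mask is added to the attention scores. A 3D mask must
    # have shape (B * num_heads, N, N), so share the bias across heads.
    bias = bias.repeat_interleave(attn.num_heads, dim=0)
    out, _ = attn(h, h, h, attn_mask=bias, need_weights=False)
    return out
```

With an all-zero attn_bias this should reproduce the unbiased output on the real (non-padded) nodes, which is a quick sanity check before plugging in real encodings.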
This enhancement would align the PyG implementation more closely with the original GPS layer, allowing structural encodings to be used in the attention computation. The only problem I see is how to batch real-valued adjacency matrices: each graph has its own N x N matrix, and N differs between graphs. (I am a relatively new user of PyTorch Geometric and I do not know how to properly extend the Data class.)
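One possible way around the batching question, sketched with plain PyTorch only: keep one matrix per graph and pad them to a common size after batching. The helper name pad_attn_bias is hypothetical and this is not an existing PyG utility. It assumes that the nodes of each graph occupy the first n_i slots of the dense batch, which is the layout I would expect from to_dense_batch.

```python
import torch


def pad_attn_bias(bias_list, pad_value=0.0):
    """Stack per-graph [n_i, n_i] matrices into one [B, N_max, N_max] tensor.

    Entries outside the top-left n_i x n_i block are filled with pad_value;
    they belong to padded nodes and should be masked out in attention.
    """
    n_max = max(b.size(0) for b in bias_list)
    out = bias_list[0].new_full((len(bias_list), n_max, n_max), pad_value)
    for i, b in enumerate(bias_list):
        n = b.size(0)
        out[i, :n, :n] = b
    return out


# Usage: two graphs with 2 and 3 nodes.
biases = [torch.ones(2, 2), torch.full((3, 3), 2.0)]
dense_bias = pad_attn_bias(biases)  # shape [2, 3, 3]
```

If the encodings are sparse, another option that may be worth checking is to store them as edge attributes and densify them after batching with torch_geometric.utils.to_dense_adj, which avoids extending the Data class at all.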
Alternatives
No response
Additional context
No response