rampasek / GraphGPS

Recipe for a General, Powerful, Scalable Graph Transformer
MIT License

Constructing custom GT+GNN+POSENC object on own graph #16

Jotels closed this issue 1 year ago

Jotels commented 1 year ago

Hey

I am currently working on a problem where I would like to try out different sets of PE/SE encodings combined with different types of GT and GNNs. From the wording of the GraphGPS paper and TDS post, I half expected that there would be a very clear way of constructing one or several objects that altogether would make up such a model.

In my specific case, I have tensors containing a predefined set of node features and node-node indices indicating edges.

I see that you provide config files specific to each of your datasets, and I see the object components that make up the final GPS model in each case. What I don't understand is how this would all be put together.

Say I wanted to construct an arbitrary model according to the recipe provided in the paper, how would I go about this given my node features and edge indices?

rampasek commented 1 year ago

Hi,

The main model is here: https://github.com/rampasek/GraphGPS/blob/main/graphgps/network/gps_model.py and the GPS layer implementation in particular is here: https://github.com/rampasek/GraphGPS/blob/main/graphgps/layer/gps_layer.py The GPS layer can already be configured with several types of MPNNs and global attention modules. Adding another MPNN type should be straightforward from there.

If you want to create new types of general PE/SEs, then follow the examples of LapPE, RWSE, and the other three encodings (I can provide more guidance later if this is what you are after).
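For readers unfamiliar with what a Laplacian PE computes, here is a minimal NumPy sketch of the idea. It is not the GraphGPS implementation: `laplacian_pe` and its arguments are hypothetical names, the graph is assumed to be undirected and connected, and `edge_index` is assumed to be an integer array of shape (2, num_edges). Requests for more eigenvectors than the graph has are not padded here.

```python
import numpy as np


def laplacian_pe(edge_index, num_nodes, k):
    """Eigenpairs of the symmetric normalized Laplacian, skipping the first.

    Hypothetical helper for illustration; assumes an undirected,
    connected graph so that the skipped eigenvalue is the trivial zero.
    """
    edge_index = np.asarray(edge_index)
    # Dense symmetric adjacency matrix from (source, target) index pairs.
    adj = np.zeros((num_nodes, num_nodes))
    adj[edge_index[0], edge_index[1]] = 1.0
    adj[edge_index[1], edge_index[0]] = 1.0

    deg = adj.sum(axis=1)
    # D^{-1/2}; isolated nodes stay at zero to avoid dividing by zero.
    inv_sqrt = np.zeros_like(deg)
    inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5

    # L = I - D^{-1/2} A D^{-1/2}
    lap = np.eye(num_nodes) - inv_sqrt[:, None] * adj * inv_sqrt[None, :]

    # eigh returns eigenvalues in ascending order for symmetric matrices.
    eigvals, eigvecs = np.linalg.eigh(lap)
    return eigvals[1:k + 1], eigvecs[:, 1:k + 1]


# Usage: a 4-node cycle, two encoding dimensions per node.
cycle = np.array([[0, 1, 2, 3], [1, 2, 3, 0]])
vals, pe = laplacian_pe(cycle, num_nodes=4, k=2)
```

Each row of `pe` is the positional encoding of one node. Eigenvectors are only defined up to sign (and up to rotation when eigenvalues repeat), so any model consuming them has to tolerate that ambiguity.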

But it sounds like you have a custom dataset and you are primarily interested in how to plug it into the GPS pipeline. For that, look at the main loader function: https://github.com/rampasek/GraphGPS/blob/6305368446275f9d5f3736d44bf265214d4f9a9b/graphgps/loader/master_loader.py#L82 and the related per-dataset formatting functions in the same file. Examples of several custom PyG dataset classes are in https://github.com/rampasek/GraphGPS/tree/main/graphgps/loader/dataset -- there is nothing GPS-specific about the dataset classes; they follow the PyG dataset classes from OGB or core PyG. See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html

I hope this helps! Ladislav

Jotels commented 1 year ago

Hi Ladislav

Thanks for getting back to me so quickly. I am, however, still not entirely sure how I would go about this.

To elaborate a bit more, I am running a reinforcement learning setup where the actions reconfigure a simulated physical system with a graph representation. Thus, the graph representation changes when the object is reconfigured.

My setup looks something like this:

  1. System is represented as a graph using a predefined node embedding function (an atomic embedding). This gives me all the basic node/edge features and edge indices.
  2. Add PE/SE features <-- (Would like to use GPS pipeline here)
  3. Apply MPNN layers <-- (Would like to use GPS pipeline here)
  4. Apply Attention layers <-- (Would like to use GPS pipeline here)
  5. Use a readout layer to output a probability distribution and a state value.

Since I already have my fundamental node embeddings and connections from step 1, I currently am able to call something like

updated_nodes, updated_edges = message_and_transform(node_features, edge_indices, edge_features)

Here 'message_and_transform' takes my basic node features, adds the PE/SE features (currently the basic DGL versions of LapPE and RWPE), does the message passing, and applies transformer layers. My implementation is rather basic, however, so ideally I would like to replace this message_and_transform module with a customizable version of the GPS pipeline in order to test out different combinations of pos_encoding+MPNN+GT.

Do you have any further suggestions?

rampasek commented 1 year ago

Hi,

I think you'll need to make several custom adjustments to turn it into what you need, but it should be possible to reuse the key components in your own project's codebase.

From what I understand, the main part you'll need to implement is dynamically building a batch of your current graphs and applying PE/SE to it; pushing the batch through the rest of GPS should then be straightforward.
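To make the batching step concrete, here is a rough NumPy sketch of how several small graphs can be merged into one disjoint graph. `collate_graphs` is a hypothetical name, and each graph is assumed to be a `(node_features, edge_index)` pair with `edge_index` of shape (2, num_edges). In a PyG-based setup, `torch_geometric.data.Batch.from_data_list` performs this kind of collation on `Data` objects, so this is only meant to show the idea.

```python
import numpy as np


def collate_graphs(graphs):
    """Merge (node_features, edge_index) pairs into one disjoint graph.

    Hypothetical helper for illustration. Returns the stacked node
    features, the shifted edge indices, and a vector that maps every
    node to the position of its source graph in `graphs`.
    """
    xs, edges, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        # Shift node ids so they stay unique inside the merged graph.
        edges.append(np.asarray(edge_index) + offset)
        batch.append(np.full(len(x), graph_id))
        offset += len(x)
    return (
        np.concatenate(xs, axis=0),
        np.concatenate(edges, axis=1),
        np.concatenate(batch),
    )


# Usage: a 2-node graph and a 3-node path, both with 3 node features.
g1 = (np.zeros((2, 3)), np.array([[0], [1]]))
g2 = (np.ones((3, 3)), np.array([[0, 1], [1, 2]]))
x, edge_index, batch = collate_graphs([g1, g2])
```

In a reinforcement learning loop where the graph changes after every action, this collation would be redone on the current graphs at each step instead of being precomputed once per dataset.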

Here is the function that precomputes PE/SE stats: https://github.com/rampasek/GraphGPS/blob/main/graphgps/transform/posenc_stats.py It is applied to every graph in a dataset here: https://github.com/rampasek/GraphGPS/blob/8b9309ce6ad5c56ecb15ceda2887aa6ee65eb922/graphgps/loader/master_loader.py#L187
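As an illustration of the kind of per-graph statistic involved, here is a minimal NumPy sketch of a random-walk structural encoding: for each node, the probability that a random walk returns to it after 1, 2, ..., K steps. This is a sketch of the general technique and not a copy of `posenc_stats.py`; `random_walk_se` and its arguments are hypothetical names, and the graph is assumed to be undirected and small enough for a dense adjacency matrix.

```python
import numpy as np


def random_walk_se(edge_index, num_nodes, num_steps):
    """Return-probabilities of random walks of length 1..num_steps.

    Hypothetical helper for illustration. The result has shape
    (num_nodes, num_steps); column j holds the diagonal of P^(j+1),
    where P = D^{-1} A is the random-walk transition matrix.
    """
    edge_index = np.asarray(edge_index)
    adj = np.zeros((num_nodes, num_nodes))
    adj[edge_index[0], edge_index[1]] = 1.0
    adj[edge_index[1], edge_index[0]] = 1.0

    deg = adj.sum(axis=1, keepdims=True)
    # Row-normalize; isolated nodes keep an all-zero row.
    trans = np.divide(adj, deg, out=np.zeros_like(adj), where=deg > 0)

    power = np.eye(num_nodes)
    feats = []
    for _ in range(num_steps):
        power = power @ trans
        feats.append(np.diag(power))
    return np.stack(feats, axis=1)


# Usage: a triangle, walks of length 1 to 3.
triangle = np.array([[0, 1, 2], [1, 2, 0]])
rwse = random_walk_se(triangle, num_nodes=3, num_steps=3)
```

Because the encoding depends only on the current graph, it can be recomputed whenever the graph is reconfigured and concatenated to the node features before they enter the model.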

I hope this helps at least a little. Your project is very interesting, but this use case is outside the direct applicability of this repo. Good luck with your project! Ladislav