pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Upgrade edge_weight to first-class citizens 🦾 #6871

Open adrianjav opened 1 year ago

adrianjav commented 1 year ago

🛠 Proposed Refactor

We (@psanch21 and I) have been thinking about how to seamlessly integrate our latest work (LCAT) into PyG. In short, we propose a way of combining GAT-like networks with GCN-like networks. However, we have struggled to find an open-closed implementation that frees us from writing a separate LCAT variant for every existing GCN model.

One key point that we make in the paper (Eq. 13 of Appendix F) is that attention can be seen as a way of assigning weights different from 1 to the adjacency matrix of the graph, and that we can integrate it with other GCN models if we extend the usual formulation

$$\mathbf{x}_i' = \bigoplus_{j \in \mathcal{N}(i)} \phi(\mathbf{x}_i, \mathbf{x}_j)$$

to include the edge weights:

$$\mathbf{x}_i' = \bigoplus_{j \in \mathcal{N}(i)} w_{ij} \, \phi(\mathbf{x}_i, \mathbf{x}_j)$$

And, looking at #6867, we could go one step further and write it as

$$\mathbf{x}_i' = \bigoplus_{j \in \mathcal{N}(i)} \left( \mathbf{x}_i, \, \phi(\mathbf{x}_i, \mathbf{x}_j), \, w_{ij} \right)$$

where I have included the point $\mathbf{x}_i$ in the aggregator function, which now takes the edge weights as another argument.
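To make this viewpoint concrete, here is a minimal sketch (not LCAT itself; the attention logits are random placeholders standing in for a learned attention) of feeding attention coefficients into a GCN-style layer as edge weights:

import torch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import softmax

x = torch.randn(4, 16)                        # 4 nodes, 16 features
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

scores = torch.randn(edge_index.size(1))      # placeholder attention logits
edge_weight = softmax(scores, edge_index[1])  # normalize per target node

# Disable GCN normalization so the attention weights are used as-is:
conv = GCNConv(16, 32, normalize=False, add_self_loops=False)
out = conv(x, edge_index, edge_weight=edge_weight)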

From this point of view, and looking into the GNN Cheatsheet, I have realized that PyG could benefit a lot from a simple refactor.

Suggest a potential alternative/fix

This is the API refactor I would propose:

This would imply the following benefits (some of which require adapting parts of the code):

If green-lighted, I can implement the suggested changes myself and push a PR, as well as introduce our LCAT model in PyG (in a separate PR, I suppose).

sigeisler commented 1 year ago

@adrianjav As I also want the feature of "weighted" aggregations, I could take over this part or wait for you. Both options are fine with me.

adrianjav commented 1 year ago

@sigeisler sure, I'd be more than happy if you can take care of that part :)

adrianjav commented 1 year ago

What are your thoughts on the suggested changes @rusty1s ?

rusty1s commented 1 year ago

I personally like the idea of moving this to its own module. There is actually a lot of duplicated code across modules, and unifying it would be very welcome. If I understand the proposal correctly, this would boil down to:

propagate
       |
edge_updater -> edge_index, edge_weight
       | 
message
       |
aggregate

This sounds quite powerful, especially with some potential composition support for edge_updater, e.g., edge_updater = Compose([AddSelfLoops(), GCNNorm(), Cache()]).
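A minimal sketch of what such composition support could look like (Compose, AddSelfLoops, GCNNorm, and Cache are hypothetical names here, not existing PyG classes; each updater is assumed to map (edge_index, edge_weight) to an updated pair):

from typing import List, Optional, Tuple

import torch
from torch import Tensor


class Compose(torch.nn.Module):
    # Chains edge updaters; each one maps (edge_index, edge_weight)
    # to an updated (edge_index, edge_weight) pair.
    def __init__(self, updaters: List[torch.nn.Module]):
        super().__init__()
        self.updaters = torch.nn.ModuleList(updaters)

    def forward(self, edge_index: Tensor,
                edge_weight: Optional[Tensor] = None
                ) -> Tuple[Tensor, Optional[Tensor]]:
        # Thread the graph through each updater in order:
        for updater in self.updaters:
            edge_index, edge_weight = updater(edge_index, edge_weight)
        return edge_index, edge_weight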

I am still a bit unsure about the interface of edge_updater though - just letting it return a new edge_weight feels kinda limited to me - in the end, edge_updater also needs to be able to update edge_index (e.g., for adding self-loops), and it needs to be able to return an arbitrary set of edge features. I am also unsure about the interplay of edge_updater and message. How would we make the output of edge_updater accessible in message?

adrianjav commented 1 year ago

Glad to hear you like the idea! And yes, you understood it perfectly.

I'd propose the following (let me know what you think):

I have a follow-up question @rusty1s: is there any reason to distinguish between edge_index and edge_weight? Wouldn't it be way simpler to keep only edge_weight and obtain edge_index via edge_weight != 0.0? I am most familiar with the code in nn.conv, so maybe there is something I am missing.

rusty1s commented 1 year ago

Mh, maybe we don't need to overthink the interface right now, and should just start with some modules that perform edge updates and can be directly plugged into existing GNN layers, e.g., GCNNorm or Attention, and then think about how to make this available as part of propagate? I personally think that right now it is non-trivial to integrate this into propagate, since the mapping from the outputs of edge_updater to the argument names of message is unspecified (or at least it is not clear to me yet how you would want to do it).

is there any reason to distinguish between edge_index and edge_weight?

How would this work without this differentiation? In that case, we would need to work on a dense adjacency matrix, right?
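To illustrate: edge_weight is per-edge and aligned with the columns of edge_index, so it carries no structure on its own. A small sketch of why recovering the edges from the weights alone presupposes a dense N x N matrix:

import torch

edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])
edge_weight = torch.tensor([0.5, 1.0, 0.25])  # aligned with edge_index columns

# Recovering edge_index via "edge_weight != 0.0" only works after
# materializing a dense N x N matrix, which defeats the sparse format:
dense_adj = torch.zeros(3, 3)
dense_adj[edge_index[0], edge_index[1]] = edge_weight
recovered_edge_index = (dense_adj != 0.0).nonzero().t()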

adrianjav commented 1 year ago

I still struggle to see why the integration within propagate is non-trivial. We can already access edge_index and edge_weight through the message function, so couldn't we just update those variables with edge_update before calling __collect__?
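To spell out what I mean, a rough, hypothetical sketch of the control flow (PyG's actual propagate does much more, e.g., argument inspection and fused kernels):

import torch
from torch import Tensor


def propagate_sketch(edge_updater, edge_index: Tensor,
                     edge_weight: Tensor, x: Tensor) -> Tensor:
    # 1. Let the edge updater rewrite edge_index / edge_weight first, so
    #    that the collect step below already sees the updated graph:
    edge_index, edge_weight = edge_updater(edge_index, edge_weight)

    # 2. Collect: gather source-node features per edge and weight them:
    msg = edge_weight.view(-1, 1) * x[edge_index[0]]

    # 3. Aggregate: sum messages at the target nodes:
    out = torch.zeros_like(x)
    return out.index_add_(0, edge_index[1], msg)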

In any case, I agree that going step-by-step is the best approach. Yesterday I updated my fork and started implementing changes, I will keep you updated. For now, I started implementing GCNNorm and making sure that GCNConv does not break.

To my eyes, the most significant amount of effort will go to updating existing classes and the documentation, but I could be wrong.

How would this work without this differentiation? In that case, we would need to work on a dense adjacency matrix, right?

You are completely right; I abstracted too much and forgot the actual format of the indices. My bad!

rusty1s commented 1 year ago

I still struggle to see why the integration within propagate is non-trivial. We can already access edge_index and edge_weight through the message function, so couldn't we just update those variables with edge_update before calling __collect__?

Yes, but this would only work if we relied on edge_updater only updating edge_weight, right? For example, what about the following:

def forward(self, x, edge_index, edge_weight, edge_attr1, edge_attr2):
    # Only edge_attr2 is transformed by the edge updater here:
    edge_attr2 = self.edge_updater(edge_index, edge_attr2)

    self.propagate(edge_index, edge_weight=edge_weight,
                   edge_attr1=edge_attr1, edge_attr2=edge_attr2)

adrianjav commented 1 year ago

Yes, that would work if we know which variables we want to update. So either:

  1. We restrict edge_update to updating only edge_index and edge_weight.
  2. We restrict edge_update to updating only edge_index, edge_weight, and edge_attr (sketched below). After all, those are the variables defined in the cheat sheet. If the user wants to have edge_attr1 and edge_attr2 for some reason, they can concatenate them into a single tensor.
  3. We let the method return an additional dictionary instead, so we know exactly which variables it is returning.

What do you think?
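For illustration, option 2 could boil down to a fixed signature like this (purely hypothetical):

from typing import Optional, Tuple

from torch import Tensor


def edge_update(edge_index: Tensor,
                edge_weight: Optional[Tensor] = None,
                edge_attr: Optional[Tensor] = None
                ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    # Restricted contract: update and return exactly these three, so that
    # propagate() knows which of its arguments to overwrite.
    return edge_index, edge_weight, edge_attr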

rusty1s commented 1 year ago

Yeah, you are probably right, no real reason to overthink this. My vote would go to option 2.

adrianjav commented 1 year ago

Cool, then I will keep working on my fork and see whether I can write the interface cleanly.

So far, I am trying to replace the gcn_norm call in GCNConv with a call to an EdgeUpdater object of the form

Compose((
    AddRemainingSelfLoops(fill_weight_value=2 if self.improved else 1),
    GCNNorm(self.improved),
))

and I am having my first headaches already 😂 It looks like I need to update the add_self_loops functionality to deal with SparseTensor as well. I managed to pass all tests except that one, for obvious reasons.
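For reference, a rough sketch of the direction I am exploring for the SparseTensor case (the module is hypothetical; fill_diag comes from torch_sparse):

import torch
from torch_sparse import SparseTensor, fill_diag
from torch_geometric.utils import add_remaining_self_loops


class AddRemainingSelfLoops(torch.nn.Module):
    def __init__(self, fill_weight_value: float = 1.):
        super().__init__()
        self.fill_weight_value = fill_weight_value

    def forward(self, edge_index, edge_weight=None, num_nodes=None):
        if isinstance(edge_index, SparseTensor):
            # For SparseTensor, self-loops live on the diagonal:
            return fill_diag(edge_index, self.fill_weight_value), None
        return add_remaining_self_loops(
            edge_index, edge_weight, self.fill_weight_value, num_nodes)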

I'll come back if I manage to find time and make some more progress.

rusty1s commented 1 year ago

Yeah, that one is a bit tricky to deal with. I think I can handle this on our end. If you can make a PR work with pure edge_index support, that would be sufficient.

adrianjav commented 1 year ago

I just created a draft for the PR. Regarding what we were discussing here, moving the functionality to propagate really simplifies the code (unless I have missed something).