adrianjav opened this issue 1 year ago
@adrianjav As I also desire the feature of "weighted" aggregations, I could take over this part or wait for you. Both options are fine with me.
@sigeisler sure, I'd be more than happy if you can take care of that part :)
What are your thoughts on the suggested changes @rusty1s ?
I personally like the idea of moving this to its own module. There actually exists a lot of duplicated code across modules, and unifying it would be very welcome. If I understand the proposal correctly, this would boil down to:
```
propagate
   |
   edge_updater -> edge_index, edge_weight
   |
   message
   |
   aggregate
```
This sounds quite powerful, especially with some potential composition support for `edge_updater`, e.g., `edge_updater = Compose([AddSelfLoops(), GCNNorm(), Cache()])`.
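For reference, a minimal sketch of what such composition support could look like, assuming a hypothetical `EdgeUpdater` base class (none of these names exist in PyG yet):

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


class EdgeUpdater(torch.nn.Module):
    """Hypothetical base class: maps (edge_index, edge_weight) to an
    updated (edge_index, edge_weight) pair."""
    def forward(self, edge_index: Tensor, edge_weight: Optional[Tensor] = None
                ) -> Tuple[Tensor, Optional[Tensor]]:
        raise NotImplementedError


class Compose(EdgeUpdater):
    """Chains several edge updaters, feeding each one's output into the
    next, e.g., Compose([AddSelfLoops(), GCNNorm(), Cache()])."""
    def __init__(self, updaters):
        super().__init__()
        self.updaters = torch.nn.ModuleList(updaters)

    def forward(self, edge_index, edge_weight=None):
        for updater in self.updaters:
            edge_index, edge_weight = updater(edge_index, edge_weight)
        return edge_index, edge_weight
```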
I am still a bit unsure about the interface of `edge_updater` though. Just letting it return a new `edge_weight` feels kinda limited to me; in the end, `edge_updater` also needs to be able to update `edge_index` (e.g., for adding self-loops), and it needs to be able to return an arbitrary set of edge features. I am also unsure about the interplay of `edge_updater` and `message`. How would we make the output of `edge_updater` accessible in `message`?
Glad to hear you like the idea! And yes, you understood it perfectly.
I'd propose the following (let me know what you think):

- Regarding `edge_updater`, we have different options:
  1. `edge_index, edge_weight = edge_updater(edge_index, edge_weight)` (the same interface as `gcn_norm` right now), leaving `edge_features` to each specific method. My reasoning here is that while the edges and their "strength" (a.k.a. weight) are intrinsic to the graph, their features are problem-dependent and hard to infer. This is my personal preference (a sketch of this interface follows at the end of this comment).
  2. `edge_index, edge_weight, edge_attr = edge_updater(edge_index, edge_weight, edge_attr)`, similar to `add_self_loops`. While more flexible, I am not sure how helpful it would be: I checked, and `add_self_loops` is never used in the codebase to infer `edge_attr` (and is used once to infer `edge_weight`).
  3. A `message`-like interface, `edge_index, edge_weight, others = edge_update(edge_index, edge_weight, **kwargs)`, to let the user specify custom edge updaters.
- Regarding `message`, my suggestion would be that, by default, everything regarding edge weights is done orthogonally to `message` through the aggregator. However, if the user wants to access this data, they could just specify `edge_weight` or `edge_attr` in the `message` interface, and `__collect__` (as it is right now) should provide the correct data (unless I misunderstood the code).

I have a follow-up question @rusty1s: is there any reason to distinguish between `edge_index` and `edge_weight`? Wouldn't it be way simpler to leave `edge_weight` alone and obtain `edge_index` via `edge_weight != 0.0`? I am most familiar with the code in `nn.conv`, so maybe there is something I am missing.
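For concreteness, here is a minimal sketch of what option 1 could look like for GCN-style symmetric normalization, in pure PyTorch (the function name and details are illustrative, not the actual `gcn_norm` implementation):

```python
from typing import Tuple

import torch
from torch import Tensor


def gcn_norm_update(edge_index: Tensor,
                    edge_weight: Tensor) -> Tuple[Tensor, Tensor]:
    """Option-1 interface: takes and returns (edge_index, edge_weight).

    Rescales each weight as in the symmetric GCN normalization
    D^{-1/2} A D^{-1/2}.
    """
    row, col = edge_index[0], edge_index[1]
    num_nodes = int(edge_index.max()) + 1
    # Weighted in-degree of every node.
    deg = torch.zeros(num_nodes, dtype=edge_weight.dtype)
    deg.scatter_add_(0, col, edge_weight)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0.0
    edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    return edge_index, edge_weight
```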
Mh, maybe we don't need to overthink the interface right now and can just start with some modules that perform edge updates and can be directly plugged into existing GNN layers, e.g., `GCNNorm` or `Attention`, and then think about how to make this available as part of `propagate`? I personally think that, right now, it is non-trivial to integrate this into `propagate`, since the mapping from the outputs of `edge_updater` to the argument names of `message` is unspecified (or at least it is not clear to me yet how you would want to do it).
> is there any reason to distinguish between edge_index and edge_weight?

How would this work without this differentiation? In that case, we would need to work on a dense adjacency matrix, right?
I still struggle to see how the integration within `propagate` is non-trivial. We can already access `edge_index` and `edge_weight` through the message function; couldn't we just update those variables with `edge_update` before calling `__collect__`?
In any case, I agree that going step-by-step is the best approach. Yesterday I updated my fork and started implementing changes; I will keep you updated. For now, I started implementing `GCNNorm` and making sure that `GCNConv` does not break.
To my eyes, the most significant amount of effort will go into updating existing classes and the documentation, but I could be wrong.
> How would this work without this differentiation? In that case, we would need to work on a dense adjacency matrix, right?
You are completely right, I abstracted too much and forgot the actual format of the indexes. My bad!
> I still struggle to see how the integration within propagate is non-trivial. We can already access edge_index and edge_weight through the message function, couldn't we just update those variables with edge_update before calling collect?
Yes, but this would only work if we relied on `edge_updater` only updating `edge_weight`, right? For example, what about the following:
```python
def forward(self, x, edge_index, edge_weight, edge_attr1, edge_attr2):
    edge_attr2 = self.edge_updater(edge_index, edge_attr2)
    return self.propagate(edge_index, x=x, edge_weight=edge_weight,
                          edge_attr1=edge_attr1, edge_attr2=edge_attr2)
```
Yes, that would work if we know the variables we want to update. So either:

1. We restrict `edge_update` to updating `edge_weight` only; if users want to update `edge_attr` and `edge_attr2` for some reason, they can concatenate them in a single tensor.
2. We follow the `message`-like interface and let `edge_update` declare which variables it takes and returns (a rough sketch of this dispatch follows below).

What do you think?
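Purely for illustration, that `message`-like dispatch could look roughly like this (`run_edge_update` is a hypothetical helper, not PyG code; it assumes `edge_update` returns its outputs in the same order as its declared inputs):

```python
import inspect


def run_edge_update(edge_update_fn, kwargs):
    """Call `edge_update_fn` with exactly the keyword arguments it
    declares, and write the updated values back under the same names,
    so that a later `__collect__` picks up the new tensors."""
    names = [name for name in inspect.signature(edge_update_fn).parameters
             if name in kwargs]
    outputs = edge_update_fn(**{name: kwargs[name] for name in names})
    if not isinstance(outputs, tuple):
        outputs = (outputs,)
    # Assumption: outputs come back in the same order as the inputs.
    kwargs.update(zip(names, outputs))
    return kwargs
```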
Yeah, you are probably right, no real reason to overthink this. My vote would go to option 2.
Cool, then I will keep working on my fork and see whether I can write the interface cleanly.
So far, I am trying to replace the `gcn_norm` call in `GCNConv` with a call to an `EdgeUpdater` object of the form

```python
Compose((
    AddRemainingSelfLoops(fill_weight_value=2 if self.improved else 1),
    GCNNorm(self.improved),
))
```

and I am having my first headaches already 😂 It looks like I need to update the `add_self_loops` functionality to deal with `SparseTensor` as well. I managed to pass all tests but that one, for obvious reasons.
I'll come back if I manage to find time and make some more progress.
Yeah, that one is a bit tricky to deal with. I think I can handle this on our end. If you can make a PR work with pure `edge_index` support, that would be sufficient.
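For the pure `edge_index` path, such a wrapper could be as thin as the following sketch (`AddRemainingSelfLoops` as a module is hypothetical; it just delegates to the existing `torch_geometric.utils.add_remaining_self_loops` and leaves `SparseTensor` handling aside):

```python
import torch
from torch_geometric.utils import add_remaining_self_loops


class AddRemainingSelfLoops(torch.nn.Module):
    """Hypothetical EdgeUpdater-style wrapper; edge_index support only."""
    def __init__(self, fill_value: float = 1.0):
        super().__init__()
        self.fill_value = fill_value

    def forward(self, edge_index, edge_weight=None):
        # Returns (edge_index, edge_weight) with self-loops appended for
        # nodes that do not have one yet, weighted by `fill_value`.
        return add_remaining_self_loops(edge_index, edge_weight,
                                        self.fill_value)
```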
I just created a draft for the PR. Regarding what we were discussing here, moving the functionality to `propagate` really simplifies the code (unless I have missed something).
🛠 Proposed Refactor
We (@psanch21 and I) have been thinking about how to seamlessly integrate our recent work (LCAT) into PyG. In short, we propose a way of combining GAT-like networks with GCN-like networks. However, we have struggled to find an open-closed implementation that frees us from writing one LCAT implementation per existing GCN model.
One key point that we make in the paper (Eq. 13 of Appendix F) is that attention can be seen as a way of assigning weights different from 1 to the adjacency matrix of the graph, and that we can integrate it with other GCN models if we rewrite the usual formulation,

$$\mathbf{x}_i' = \gamma\left(\mathbf{x}_i,\ \bigoplus_{j \in \mathcal{N}(i)} \phi(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j,i})\right),$$

to include the edge weights:

$$\mathbf{x}_i' = \gamma\left(\mathbf{x}_i,\ \bigoplus_{j \in \mathcal{N}(i)} w_{j,i}\,\phi(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j,i})\right).$$

And, looking at #6867, we could go one step further and write it as

$$\mathbf{x}_i' = \gamma\left(\bigoplus_{j \in \mathcal{N}(i) \cup \{i\}} \big(w_{j,i},\ \phi(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j,i})\big)\right),$$

where I have included the point $\mathbf{x}_i$ in the aggregator function, which now takes the edge weights as another argument.
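In code, the weighted aggregation boils down to something like this pure-PyTorch sketch (names are illustrative, and a weighted sum stands in for the generic $\bigoplus$):

```python
import torch
from torch import Tensor


def weighted_sum_aggregate(x: Tensor, edge_index: Tensor,
                           edge_weight: Tensor) -> Tensor:
    """Computes x_i' = sum_{j in N(i)} w_{j,i} * x_j: the aggregator
    receives the edge weights as an extra argument."""
    row, col = edge_index[0], edge_index[1]  # edges go row -> col
    out = torch.zeros_like(x)
    # Scale each incoming message by its edge weight and scatter-add
    # it into the target node's slot.
    out.index_add_(0, col, edge_weight.view(-1, 1) * x[row])
    return out
```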
From this point of view, and looking into the GNN Cheatsheet, I have realized that PyG could benefit a lot from a simple refactor.
Suggest a potential alternative/fix
This is the API refactor I would propose:

1. Remove `edge_updater` from `MessagePassing` (it is only used by `GATConv`). Call `edge_update` at the beginning of `propagate` instead to compute the weights of the edges.
2. Introduce a class `EdgeUpdater`, similar to `Aggregation`, that can be passed to a conv layer (a rough sketch follows at the end of this proposal).

This would imply the following benefits (some need to adapt parts of the code):

- The GCN normalization would be implemented once (instead of duplicated in `GCNConv` and `GCN2Conv`) and it could be used with any layer.
- Similar to the normalization used by `GCNConv`, we could implement a class `Attention` which inherits from `EdgeUpdater` and abstractly implements attention layers like `GAT`, `GAT2`, and `SuperGAT` (removing the weight matrices and biases that multiply the inputs).
- `GATConv`, for example, would simply be a `GCNConv` with `GAT` as edge updater and an aggregator that takes the attention heads into account.
- We could combine any conv layer (e.g., `PNAConv`) with an attention module (e.g., `GAT2`), as long as we pass the proper `EdgeUpdater` and `Aggregation` instances.

If green-lighted, I can do the suggested changes myself and push a PR, as well as introduce our LCAT model in PyG (in a different PR, I suppose).
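And the rough hierarchy sketch mentioned above (all class names are illustrative, not existing PyG classes):

```python
import torch


class EdgeUpdater(torch.nn.Module):
    """Analogue of `Aggregation`: computes updated edge weights."""
    def forward(self, edge_index, edge_weight=None, **kwargs):
        raise NotImplementedError


class Attention(EdgeUpdater):
    """Abstract attention-style updater: subclasses (e.g., a GAT- or
    GATv2-style scorer) only implement `score` and `normalize`."""
    def forward(self, edge_index, edge_weight=None, x=None, **kwargs):
        alpha = self.score(edge_index, x)  # raw score per edge
        return edge_index, self.normalize(edge_index, alpha)

    def score(self, edge_index, x):
        raise NotImplementedError

    def normalize(self, edge_index, alpha):
        # Typically a softmax over the incoming edges of each target node.
        raise NotImplementedError
```

Under this proposal, a `GATConv`-like layer would then be assembled as something like `GCNConv(..., edge_updater=SomeAttention(...), aggr=SomeAggregation(...))`.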