pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Graph Pooling Framework 🚀 #6455

Closed rusty1s closed 1 year ago

rusty1s commented 1 year ago

🚀 The feature, motivation and pitch

The "Understanding Pooling in Graph Neural Networks" paper introduces a simple framework to unify (most) graph pooling approaches:

  1. Select: Selects input nodes to map to supernodes
  2. Reduce: Reduces the supernodes to singletons
  3. Connect: Decides how the new nodes are connected
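
As a rough illustration of the three steps, here is a framework-agnostic sketch on a 4-node path graph using plain Python lists. The function names, the fixed-size grouping used for Select, and the mean used for Reduce are assumptions chosen for illustration; they are not taken from the paper or from PyG.

```python
from collections import defaultdict


def select(num_nodes, group_size=2):
    """Select: assign each input node to a supernode.

    Hypothetical rule: consecutive nodes are grouped in blocks of `group_size`.
    """
    return [i // group_size for i in range(num_nodes)]


def reduce(x, cluster):
    """Reduce: collapse each supernode to one value (here: the mean)."""
    groups = defaultdict(list)
    for value, c in zip(x, cluster):
        groups[c].append(value)
    return [sum(groups[c]) / len(groups[c]) for c in sorted(groups)]


def connect(edges, cluster):
    """Connect: link two supernodes if any of their member nodes were linked."""
    coarse = {(cluster[u], cluster[v]) for u, v in edges}
    # Drop self-loops created by edges inside a single supernode.
    return sorted((u, v) for u, v in coarse if u != v)


# Path graph 0 - 1 - 2 - 3 with one scalar feature per node.
x = [1.0, 2.0, 3.0, 4.0]
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]

cluster = select(len(x))                 # [0, 0, 1, 1]
pooled_x = reduce(x, cluster)            # [1.5, 3.5]
pooled_edges = connect(edges, cluster)   # [(0, 1), (1, 0)]
```

Each step only depends on the output of Select, which is what makes it possible to swap one step without touching the others.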

Currently, the pooling operators in PyG are implemented in isolation from each other, with a lot of repeated code and inconsistent interfaces, which makes them confusing and difficult to apply. By following the above approach, we can unify the existing implementations and accelerate new research on graph pooling. For example, we can introduce base classes for each of the aforementioned steps:

import torch

class Select(torch.nn.Module):
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Returns a bipartite `edge_index` mapping input nodes to supernodes."""

class Reduce:  # No need -> we can utilize `nn.aggr` for this
    pass

class Connect(torch.nn.Module):
    def forward(self, cluster_index: torch.Tensor, edge_index: torch.Tensor, *args, **kwargs):
        """Returns a coarsened graph."""

With this, e.g., graclus can be moved to a Select operator and TopK pooling can be moved to a Connect operator.

Relevant twitter thread: https://twitter.com/riceasphait/status/1447867635442585601

Tasks:

Add concrete Select and Connect classes and update implementations of.

wsad1 commented 1 year ago

Would it be useful to add a Pooling class that all pooling operators inherit from?

import torch

class Pooling(torch.nn.Module):
    def forward(self, x, edge_index, *args, **kwargs):
        # Select -> Reduce -> Connect, each delegated to the child class.
        mapping = self.select(x, edge_index, *args, **kwargs)
        x = self.reduce(x, mapping)
        edge_index = self.connect(mapping, edge_index, *args, **kwargs)
        return x, edge_index

self.select, self.reduce, and self.connect would be objects defined in the child class, such as graclus or TopK.
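
To make the idea concrete, here is a minimal sketch of how a child class might plug into such a base class. Everything below is hypothetical: `PoolingBase` and `ToyTopKPooling` are illustrative names, the score is a plain feature sum rather than the learned projection used by PyG's TopKPooling, and the mapping is assumed to be a dense vector holding the new node index, or -1 for dropped nodes.

```python
import torch


class PoolingBase(torch.nn.Module):
    """Hypothetical base class: runs select -> reduce -> connect."""

    def forward(self, x, edge_index):
        mapping = self.select(x, edge_index)
        x = self.reduce(x, mapping)
        edge_index = self.connect(mapping, edge_index)
        return x, edge_index


class ToyTopKPooling(PoolingBase):
    """Toy top-k pooling: keeps the k nodes with the largest feature sum."""

    def __init__(self, k):
        super().__init__()
        self.k = k

    def select(self, x, edge_index):
        score = x.sum(dim=-1)  # assumption: no learned scoring vector
        # Sort kept indices so new ids follow the original node order.
        perm = torch.topk(score, self.k).indices.sort().values
        mapping = torch.full((x.size(0),), -1, dtype=torch.long)
        mapping[perm] = torch.arange(self.k)
        return mapping

    def reduce(self, x, mapping):
        # Each supernode is a single kept node, so reduce is a row filter.
        return x[mapping >= 0]

    def connect(self, mapping, edge_index):
        # Relabel endpoints and keep edges whose endpoints both survived.
        row, col = mapping[edge_index[0]], mapping[edge_index[1]]
        keep = (row >= 0) & (col >= 0)
        return torch.stack([row[keep], col[keep]], dim=0)


# Path graph 0 - 1 - 2 - 3; nodes 1 and 2 have the largest features.
x = torch.tensor([[1.0], [5.0], [4.0], [2.0]])
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])

pooled_x, pooled_edge_index = ToyTopKPooling(k=2)(x, edge_index)
```

With this split, a different operator would only need to override the step that actually differs, for example a new `select` with the same `connect`.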

rusty1s commented 1 year ago

Yes, definitely :)

danielegrattarola commented 1 year ago

Hey, first author of the paper here!

This is super! If help is still wanted (as I see on the tags), I'd be happy to participate!

Worth mentioning that there is an implementation of SRC in Spektral, along with a few layers implemented with it. It might be possible to translate that implementation from TF to Torch/PyG.

wsad1 commented 1 year ago

@danielegrattarola firstly, great paper, and sorry for the late reply. If you are still interested, we could use your help moving SAGPooling and PANPooling. I just refactored TopKPooling here.

puririshi98 commented 1 year ago

benchmark of topK function

https://github.com/puririshi98/rgcn_pyg_lib_forward_bench/blob/main/topK_bench.py

puririshi98 commented 1 year ago

topk microbench PR: https://github.com/pyg-team/pytorch_geometric/pull/7549

puririshi98 commented 1 year ago

diff pool microbench PR: https://github.com/pyg-team/pytorch_geometric/pull/7550

puririshi98 commented 1 year ago

https://github.com/pyg-team/pytorch_geometric/pull/7361 PR is ready to merge

puririshi98 commented 1 year ago

https://github.com/pyg-team/pytorch_geometric/pull/7625