pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Addition of Mean Subtraction Operator from "Revisiting “Over-smoothing” in Deep GCNs" #5056

Closed: Alvaro-Ciudad closed this issue 2 years ago

Alvaro-Ciudad commented 2 years ago

🚀 The feature, motivation and pitch

I have been working on graph autoencoders, and I had some problems with oversmoothing. I've tried a few alternatives, and this is one of the best-performing ones. I also believe that PyG could benefit from more methods for tackling oversmoothing. The code for the layer is already written, so it would just be a simple pull request with a few tests.

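```python
import torch
from torch_scatter import scatter


class MeanSubstraction(torch.nn.Module):
    """Subtracts the (per-graph) mean feature vector from every node."""

    def __init__(self):
        super().__init__()

    def forward(self, x, batch=None):
        if batch is None:
            # Single graph: center all node features around the global mean.
            return x - x.mean(dim=0, keepdim=True)
        # Mini-batch of graphs: compute one mean feature vector per graph...
        mean = scatter(x, batch, dim=0, reduce='mean')
        # ...and broadcast it back to every node of that graph.
        return x - mean.index_select(0, batch)
```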
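Written out as an equation (our reading of the code above, not a formula quoted from the paper), the layer computes the following for every node $i$ of graph $g$, where $\mathcal{V}_g$ is the set of nodes whose `batch` entry equals $g$. When `batch` is `None`, the whole input is treated as a single graph.

```latex
\tilde{\mathbf{x}}_i \;=\; \mathbf{x}_i \;-\; \frac{1}{|\mathcal{V}_g|} \sum_{j \in \mathcal{V}_g} \mathbf{x}_j,
\qquad i \in \mathcal{V}_g
```

After this step, the node features of each graph have zero mean in every feature dimension.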

Alternatives

It could also be added as a flag inside the PairNorm implementation, as it is a special case of that normalization.
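A minimal sketch of what such a flag might look like, written in plain PyTorch (no `torch_scatter`). The names `PairNormWithFlag`, `per_graph_mean` and `center_only` are hypothetical and are not PyG's API. The rescaling step follows PairNorm as we understand it: divide the centered features by the root mean squared node norm within each graph, then multiply by a scale `s`.

```python
import torch


def per_graph_mean(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Mean row of `x` for each graph; returns shape [num_graphs, x.size(1)]."""
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    sums.index_add_(0, batch, x)  # sum the rows belonging to each graph
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return sums / counts.unsqueeze(-1).to(x.dtype)


class PairNormWithFlag(torch.nn.Module):
    """Hypothetical PairNorm variant with a `center_only` flag.

    center_only=True  -> plain mean subtraction (the proposed layer)
    center_only=False -> mean subtraction followed by PairNorm-style rescaling
    """

    def __init__(self, scale: float = 1.0, center_only: bool = False, eps: float = 1e-5):
        super().__init__()
        self.scale = scale
        self.center_only = center_only
        self.eps = eps

    def forward(self, x, batch=None):
        if batch is None:
            # Treat the whole input as one graph.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        num_graphs = int(batch.max()) + 1

        # Step 1 (shared): subtract each graph's mean feature vector.
        x = x - per_graph_mean(x, batch, num_graphs)[batch]
        if self.center_only:
            return x

        # Step 2 (PairNorm only): divide by the root mean squared row norm
        # of the centered features within each graph, then multiply by `scale`.
        sq_norm = x.pow(2).sum(dim=-1, keepdim=True)          # [N, 1]
        mean_sq = per_graph_mean(sq_norm, batch, num_graphs)  # [G, 1]
        return self.scale * x / (mean_sq[batch] + self.eps).sqrt()
```

With `center_only=True` the output matches the proposed mean subtraction layer. The flag only skips the rescaling, which is why the operator fits naturally as an option on PairNorm.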

Additional context

The paper in question: https://arxiv.org/pdf/2003.13663v1.pdf

rusty1s commented 2 years ago

Thanks for sharing. This looks super easy to integrate within the nn.norm package. Let me know if you want to work on this. @lightaime and @Padarn can help you further.

Padarn commented 2 years ago

Sure I'll pick it up!

Alvaro-Ciudad commented 2 years ago

Great, if you pick it up, that works for me :)

lightaime commented 2 years ago

Do we want to add it to nn.aggr? To me, it looks more like something we should add to nn.norm. It is part of PairNorm, as @Alvaro-Ciudad mentioned.

rusty1s commented 2 years ago

Oh, you are right. I totally misread.

Padarn commented 2 years ago

I also misread this (sorry, I was on my phone)... I made a PR here: https://github.com/pyg-team/pytorch_geometric/pull/5068, but I'm happy for @Alvaro-Ciudad to take over and make any changes that make sense to you.

I added it to the nn.norm package, but added support for using any aggregator from the new nn.aggr package.
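The idea of subtracting an arbitrary per-graph aggregate instead of only the mean can be illustrated with a small sketch. It does not reproduce the code in the linked PR. `AggrSubtraction` and its `reduce` argument are hypothetical, and the sketch relies on `torch.Tensor.scatter_reduce` (assuming PyTorch 1.12 or later) rather than the `nn.aggr` modules.

```python
from typing import Optional

import torch


class AggrSubtraction(torch.nn.Module):
    """Hypothetical generalization: subtract a per-graph aggregate, not only the mean.

    `reduce` is forwarded to torch.Tensor.scatter_reduce; "mean" recovers the
    mean subtraction layer proposed in this issue.
    """

    def __init__(self, reduce: str = "mean"):
        super().__init__()
        if reduce not in ("mean", "amax", "amin"):
            raise ValueError(f"unsupported reduce: {reduce!r}")
        self.reduce = reduce

    def forward(self, x: torch.Tensor, batch: Optional[torch.Tensor] = None) -> torch.Tensor:
        if batch is None:
            # Treat the whole input as one graph.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        num_graphs = int(batch.max()) + 1
        # scatter_reduce needs an index with the same shape as the source.
        index = batch.unsqueeze(-1).expand_as(x)
        # include_self=False: aggregate only the node features, not the zero init.
        agg = x.new_zeros(num_graphs, x.size(1)).scatter_reduce(
            0, index, x, reduce=self.reduce, include_self=False
        )
        # Broadcast each graph's aggregate back to its nodes and subtract it.
        return x - agg[batch]
```

For example, with `reduce="amax"` every feature becomes non-positive within each graph, while `reduce="mean"` reproduces the zero-mean centering of the original layer.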