lukecavabarrett / pna

Implementation of Principal Neighbourhood Aggregation for Graph Neural Networks in PyTorch, DGL and PyTorch Geometric
https://arxiv.org/abs/2004.05718
MIT License
338 stars, 55 forks

Simpler version of the PNA #3

Closed · DomInvivo closed this issue 4 years ago

DomInvivo commented 4 years ago

For some tasks I'm working on personally, I found that the MPNN-style architecture does not perform well, no matter which aggregators or scalers are used. Even the simple sum and mean aggregators perform worse than their GIN and GCN cousins.

For this reason, I propose to add the following simpler architecture as a variant of the PNA layer. It doesn't use the MPNN attention mechanism, but instead aggregates the neighbours in a way similar to CNN, GCN and GIN layers. An obvious drawback is the lack of edge features, but on my personal project on a molecular dataset, edge features seem to cause more overfitting.
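For clarity, here is a sketch of the per-node update this layer computes (my own notation, not from the paper; $A$ is the set of aggregators, $S$ the set of scalers, $d_i$ the degree of node $i$, and $\Vert$ denotes concatenation):

$$h_i' = \mathrm{MLP}\Big(h_i \,\big\Vert\, \big\Vert_{s \in S}\, s\big(\big\Vert_{a \in A}\, a\big(\{h_j : j \in \mathcal{N}(i)\}\big),\ d_i\big)\Big)$$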

I propose to add it in the file pna/models/dgl/pna_layer.py. I did not implement it in PyTorch Geometric or plain PyTorch.

import torch
import torch.nn as nn
import dgl.function as fn

# repo-local imports; the exact paths are assumed here and may need adjusting to the repository layout
from .aggregators import AGGREGATORS
from .scalers import SCALERS
from ..layers import MLP


class PNASimpleLayer(nn.Module):

    def __init__(self, in_dim, out_dim, aggregators, scalers, avg_d, dropout, batch_norm, activation,
                posttrans_layers=1, residual=False):
        """
        A PNA layer that simply aggregates the neighbourhood (similar to GCN and GIN),
        without using the attention mechanism of the MPNN. It does not support edge features.

        :param in_dim:              size of the input per node
        :param out_dim:             size of the output per node
        :param aggregators:         space-separated string of aggregator names (keys of AGGREGATORS)
        :param scalers:             space-separated string of scaler names (keys of SCALERS)
        :param avg_d:               average degree of nodes in the training set, used by scalers to normalize
        :param dropout:             dropout rate used in the post-transformation MLP
        :param batch_norm:          whether to use batch normalisation
        :param activation:          activation function used in the post-transformation MLP
        :param posttrans_layers:    number of layers in the transformation after the aggregation
        :param residual:            whether to add a residual connection (only applied when in_dim == out_dim)
        """
        super().__init__()

        # retrieve the aggregators and scalers functions
        aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()]
        scalers = [SCALERS[scale] for scale in scalers.split()]

        self.aggregators = aggregators
        self.scalers = scalers
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dropout = dropout
        self.batch_norm = batch_norm

        # a plain residual connection is only possible when input and output dimensions match
        self.residual = residual and (in_dim == out_dim)

        self.batchnorm_h = nn.BatchNorm1d(out_dim)
        self.activation = activation
        self.posttrans = MLP(in_size=(len(aggregators) * len(scalers) + 1) * in_dim, hidden_size=out_dim,
                            out_size=out_dim, layers=posttrans_layers, 
                            mid_activation=activation, last_activation=activation,
                            dropout=dropout, mid_b_norm=batch_norm, last_b_norm=batch_norm)
        self.avg_d = avg_d

    def reduce_func(self, nodes):
        # nodes.mailbox['m'] has shape (num_nodes_in_bucket, num_neighbours, in_dim)
        h = nodes.mailbox['m']
        D = h.shape[-2]  # in-degree of the nodes in this bucket
        # concatenate the outputs of all aggregators, then apply and concatenate all scalers
        h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1)
        h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1)
        return {'h': h}

    def forward(self, g, h):
        h_in = h  # kept for the residual connection
        g.ndata['h'] = h

        # aggregation: each node receives copies of its neighbours' features,
        # which are combined by the aggregators and scalers in reduce_func
        g.update_all(fn.copy_u('h', 'm'), self.reduce_func)
        h = torch.cat([h, g.ndata['h']], dim=1)

        # post-transformation MLP (activation, dropout and batch norm are handled inside the MLP)
        h = self.posttrans(h)

        if self.residual:
            h = h_in + h

        return h

    def __repr__(self):
        return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, self.in_dim, self.out_dim)
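
For reference, here is a minimal usage sketch. The aggregator/scaler names, the avg_d value and the type expected for activation are assumptions on my side and must match what AGGREGATORS, SCALERS and the MLP in this repo actually expect:

import dgl
import torch

# hypothetical settings: the aggregator/scaler names, the avg_d statistic and the
# activation argument must follow the conventions of this repository
layer = PNASimpleLayer(in_dim=64, out_dim=64,
                       aggregators='mean max min std',
                       scalers='identity amplification attenuation',
                       avg_d=2.5, dropout=0.1, batch_norm=True,
                       activation='relu', posttrans_layers=1, residual=True)

# a directed 10-node cycle, so every node has exactly one in-neighbour
src = torch.arange(10)
g = dgl.graph((src, (src + 1) % 10))
h = torch.randn(10, 64)      # input node features
out = layer(g, h)            # output of shape (10, 64)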
sooheon commented 4 years ago

Out of curiosity, can you elaborate on the type of dataset you have found this to be more effective on? Number of molecules, information about the task, etc.

DomInvivo commented 4 years ago

I am working on a molecular dataset for mRNA signature similarity matching, using the dataset provided by the following GitHub repository: deepSIBA.

After careful parameter optimization, I found that the simpler architecture performs similarly to the MPNN-based architecture on this dataset but with considerably fewer features since the towers are not required. I also found that edge features decreased model performance, which I find very weird.

gcorso commented 4 years ago

Added, thank you very much Dominique!