pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

How to implement other GraphSAGE Aggregator functions like Max Pool? #1147

Closed sbonner0 closed 2 years ago

sbonner0 commented 4 years ago

Thanks for all the great work on the package and sorry if this has been asked before.

I was wondering if it is possible to use different aggregation functions, such as max-pool or LSTM aggregation from the original GraphSAGE paper, with the existing SAGEConv layer? Looking at the documentation, it seems that only mean aggregation is available.

rusty1s commented 4 years ago

Max-Pool is indeed possible, but we do not have an implementation for it. Feel free to add it if you like. A fast LSTM aggregation is currently not possible, but I am working on it.

sbonner0 commented 4 years ago

Hi, thanks so much for the reply. I would love to add it if I am able. Would it be possible to implement it as in the paper (with a second parameter matrix W_pool) using the current SAGEConv layer, or would I need to implement it from scratch using MessagePassing? Do you have any tips on where to start?

rusty1s commented 4 years ago

We should maybe first come up with a separate implementation and then decide how to merge the different variants. I suggest copying most of the SAGEConv layer as a starting point. And yes, you need to introduce W_pool for the max aggregation.

sbonner0 commented 4 years ago

Thanks - I'll hopefully get time to look at this over the next week.

cluel01 commented 4 years ago

Hey guys, is there an update on the two aggregator functions? I would be keen to know.

sbonner0 commented 4 years ago

Hey - I am still planning to have a go at this as soon as possible; I have been inundated with work since the lockdown. If you have any ideas, or would like to help, I would be most grateful.

jchaykow commented 4 years ago

It looks like you might just have to add an additional function here, maybe called scatter_maxpool or something?

rusty1s commented 4 years ago

torch-scatter does already support max reduction :)
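
For reference, a minimal sketch of a feature-wise max reduction with torch-scatter (the toy tensors here are made up for illustration):

import torch
from torch_scatter import scatter

src = torch.tensor([[1., 3.],
                    [2., 0.],
                    [5., 4.]])       # one message per edge
index = torch.tensor([0, 0, 1])     # target node of each message
out = scatter(src, index, dim=0, reduce='max')
# out == [[2., 3.], [5., 4.]] -- feature-wise max per target node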

vijay2411 commented 4 years ago

Hi, is there any update on the max-pool aggregator? I would love to know. Also, could we not just change the aggregator by passing an extra argument to the SAGEConv constructor in torch_geometric/nn/conv/sage_conv.py? Is there anything else that needs to be done? The aggregator value there is hard-coded as 'add'. Thanks.

Update 1: Okay, somehow my Jupyter notebook still shows the previous implementation of sage_conv.py. I checked the new code and see that the custom aggregator has been added. How are we supposed to pass it as an argument - just aggr='max', right? That raises an error:

TypeError: __init__() got multiple values for keyword argument 'aggr' (line 36 of sage_conv.py)

Update 2: Okay, I think I need to reinstall torch_geometric.

Update 3: Okay, I am unable to install the latest code! What should I do? I have uninstalled and reinstalled all binaries and the pytorch_geometric repo (with the --no-cache-dir flag). It says PyG 1.6.1 is installed, but I still can't see the update.

Help! :/

Thanks very much

rusty1s commented 4 years ago

You can install from source:

pip install git+https://github.com/rusty1s/pytorch_geometric.git

or change the aggr flag after initialization:

conv = SAGEConv(...)
conv.aggr = 'max'

Curt-Park commented 3 years ago

@rusty1s

The max pooling suggested by the paper is the following:

AGGREGATE_k^pool = max({ σ(W_pool · h_{u_i}^k + b) : u_i ∈ N(v) })

conv = SAGEConv(...)
conv.aggr = 'max'

However, the approach you mentioned seems to apply the max reduction before the affine operation in the formulation above. If I wanted to do it the way described in the paper, should I set conv.aggr = None and use a feature-wise max operator?

rusty1s commented 3 years ago

You can simply apply self.lin_l before self.propagate.
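
In other words, something along these lines inside a SAGEConv-like forward pass (a sketch, assuming the layer was built with aggr='max'; the F.relu standing in for the paper's σ is an assumption):

import torch.nn.functional as F
from torch import Tensor

# Sketch: paper-style pool aggregation in a SAGEConv-like layer.
# Assumes self.aggr == 'max' and self.lin_l is the layer's linear transform.
def forward(self, x, edge_index):
    if isinstance(x, Tensor):
        x = (x, x)
    # Transform neighbour (source) features *before* message passing ...
    h = F.relu(self.lin_l(x[0]))
    # ... then reduce feature-wise with max during propagation.
    return self.propagate(edge_index, x=(h, x[1]))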

jsun57 commented 3 years ago

Hi, I am wondering if there are any updates regarding the LSTM aggregator for SAGEConv. Thanks a lot.

rusty1s commented 3 years ago

Not yet, and it has very low priority since LSTM aggregation isn't that useful for most GNNs. Nonetheless, you can easily implement this by yourself by overriding the aggregate function in MessagePassing. Given that your edge_indices are sorted column-wise, you can do:

def aggregate(self, inputs, index):
    out, mask = to_dense_batch(inputs, index)  # [num_nodes, num_neighbors, num_features]
    return self.lstm(out)
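
For completeness, a slightly more explicit variant of the same sketch (keeping the LSTM's final hidden state as the aggregated output is an assumption, not part of the snippet above):

def aggregate(self, inputs, index):
    # Group each node's incoming messages into a dense tensor of shape
    # [num_nodes, max_num_neighbors, num_features]; mask flags real entries.
    out, mask = to_dense_batch(inputs, index)
    # Run the LSTM over each neighbour sequence and keep the final
    # hidden state as the aggregated node representation.
    _, (h, _) = self.lstm(out)
    return h.squeeze(0)  # [num_nodes, hidden_size]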

hunarbatra commented 2 years ago

For max pooling - would it be right to use out = matmul(adj_t, x[0], reduce='max')?

(along with applying self.lin_l before self.propagate, to match the equation in the paper)

@rusty1s

hunarbatra commented 2 years ago

This is my code below for SAGEConv. Also, LSTM isn't working for me. Could you please tell me how to implement it, or how to fix it in my implementation? Thank you so much! :)

class SAGETest2(MessagePassing):
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, aggregator_type: str, normalize: bool = False,
                 root_weight: bool = True, bias: bool = True):
        # kwargs.setdefault('aggr', 'lstm')
        super(SAGETest2, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight
        self.aggregator_type = aggregator_type
        self.bias = bias

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        if self.aggregator_type == 'gcn':
          # GCN does not require self root node, but just the neighbours
          self.root_weight = False 

        if aggregator_type == 'max':
          self.lin_pool = Linear(in_channels[0], in_channels[0])

        if self.aggregator_type == 'lstm':
          self.lstm = LSTM(
                in_channels[0], (2 * in_channels[0]) // 2,
                bidirectional=True,
                batch_first=True)
          self.att = Linear(2 * ((2 * in_channels[0]) // 2), 1)

        # if self.aggregator_type == 'lstm':
        #   self.lstm = nn.LSTM(
        #     input_size=in_channels,
        #     hidden_size=out_channels,
        #     batch_first=True,
        #   )

        if self.aggregator_type != 'lstm':
          self.lin_l = Linear(in_channels[0], out_channels, bias=bias) # neighbours
        else:
          self.lin_l = Linear(in_channels[0] + (2 * in_channels[0]), out_channels, bias=bias) # neighbours

        if self.root_weight: # Not created for GCN
            self.lin_r = Linear(in_channels[1], out_channels, bias=False) # root itself

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        if self.root_weight:
          self.lin_r.reset_parameters()
        if self.aggregator_type == 'lstm':
          self.lstm.reset_parameters()
        if self.aggregator_type == 'max':
          self.lin_pool.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # Determine whether to apply linear transformation before message passing A(XW)
        # in_channels = self.in_channels
        # lin_before_mp = in_channels[0] > self.out_channels
        # if lin_before_mp:
        #   # Apply linear transformation before message passing to match the eqn in the paper
        #   out = self.lin_l(x[0]) # lin_l is for neighbours, and x[0] is the source i.e neighbours
        #   # propagate_type: (x: OptPairTensor)
        #   out = self.propagate(edge_index, x=x, size=size)
        # else:
        #   # propagate_type: (x: OptPairTensor)
        #   out = self.propagate(edge_index, x=x, size=size)
        #   out = self.lin_l(out)

        if self.aggregator_type == 'max':
            # Apply the pooling transform to neighbour features *before*
            # message passing, as in the paper's pool aggregator.
            x = (F.relu(self.lin_pool(x[0])), x[1])

        out = self.propagate(edge_index, x=x, size=size)
        out = self.lin_l(out)

        x_r = x[1] # x[1] -- root
        if self.root_weight and x_r is not None: 
            out += self.lin_r(x_r) # root doesn't get added for gcn

        # bias term
        # if self.bias is not None:
        #     out = out + self.bias

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    # def aggregate(self, inputs, index):
    #   if self.aggregator_type == 'lstm':
    #     out, mask = to_dense_batch(inputs[0], index)  # [num_nodes, num_neighbors, num_features]
    #     return self.lstm(out)

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor) -> Tensor:
      adj_t = adj_t.set_value(None, layout=None)
      if self.aggregator_type == 'mean' or self.aggregator_type == 'gcn':
        out = matmul(adj_t, x[0], reduce='mean')

      elif self.aggregator_type == 'max':
        # out = matmul(adj_t, x[0], reduce='max')
        out = torch.stack(x[0], dim=-1).max(dim=-1)[0]

      elif self.aggregator_type == 'lstm':
        # x, (h_n, c_n) = self.lstm(x[0])
        # out = h_n
        x = torch.stack(x[0], dim=1)  # [num_nodes, num_layers, num_channels]
        alpha, _ = self.lstm(x)
        alpha = self.att(alpha).squeeze(-1)  # [num_nodes, num_layers]
        alpha = torch.softmax(alpha, dim=-1)
        return (x * alpha.unsqueeze(-1)).sum(dim=1)
      return out

rusty1s commented 2 years ago

max aggregation should be easily achievable by just setting aggr="max" when instantiating SAGEConv.
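
For instance (assuming a PyG version where SAGEConv forwards the aggr argument to MessagePassing):

from torch_geometric.nn import SAGEConv

# Feature-wise max aggregation over neighbours, set at construction time:
conv = SAGEConv(in_channels=64, out_channels=32, aggr='max')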

Your LSTM implementation looks a bit confusing to me since you are missing the edge-level x_j matrix computation:

def message_and_aggregate(self, adj_t, x, edge_index_j, edge_index_i):
    x_j = x[0][edge_index_j]
    x, mask = to_dense_batch(x_j, edge_index_i)
    x = self.lstm(x)
    ...

hunarbatra commented 2 years ago

Thank you so much! Finally, LSTM is working for me, and this is my final SAGEConv. Do you think I should open a PR, as it might help others looking for these implementations?

from typing import Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import LSTM
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptPairTensor, Size
from torch_sparse import SparseTensor, matmul

from torch_geometric.utils import to_dense_batch

class SAGEConv(MessagePassing):
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, aggregator_type: str = 'mean', normalize: bool = False,
                 root_weight: bool = True, bias: bool = True):

        super(SAGEConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight
        self.aggregator_type = aggregator_type
        self.bias = bias

        assert self.aggregator_type in ['mean', 'max', 'lstm', 'gcn']

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        if self.aggregator_type == 'gcn':
          # GCN does not require self root node, but just the neighbours
          self.root_weight = False 

        if aggregator_type == 'max':
          self.lin_pool = Linear(in_channels[0], in_channels[0])

        if self.aggregator_type == 'lstm':
          self.lstm = LSTM(in_channels[0], in_channels[0], batch_first=True)

        self.lin_l = Linear(in_channels[0], out_channels, bias=bias) # neighbours
        if self.root_weight: # Not created for GCN
            self.lin_r = Linear(in_channels[1], out_channels, bias=False) # root itself

        self.reset_parameters()

    def reset_parameters(self):
        """
        Reinitialises learnable parameters
        """
        self.lin_l.reset_parameters()
        if self.root_weight:
          self.lin_r.reset_parameters()
        if self.aggregator_type == 'lstm':
          self.lstm.reset_parameters()
        if self.aggregator_type == 'max':
          self.lin_pool.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None) -> Tensor:
        """
        Computes GraphSAGE Layer
        """

        if isinstance(x, Tensor):
          x: OptPairTensor = (x, x)

        if self.aggregator_type == 'max':
            # Paper-style pool aggregator: transform neighbour features
            # before the feature-wise max reduction.
            x = (F.relu(self.lin_pool(x[0])), x[1])

        out = self.propagate(edge_index, x=x, size=size)
        out = self.lin_l(out)

        x_r = x[1] # x[1] -- root
        if self.root_weight and x_r is not None: 
            out += self.lin_r(x_r) # root doesn't get added for GCN

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor, edge_index_j, edge_index_i) -> Tensor:
      """
        Performs message Passing and aggregates messages from neighbours using the aggregator_type
      """
      adj_t = adj_t.set_value(None, layout=None)
      if self.aggregator_type == 'mean' or self.aggregator_type == 'gcn':
        return matmul(adj_t, x[0], reduce='mean')

      elif self.aggregator_type == 'max':
        return matmul(adj_t, x[0], reduce='max')
        # return torch.stack(x[0], dim=-1).max(dim=-1)[0] - alternative implementation of max pool operation

      elif self.aggregator_type == 'lstm':
        # Gather per-edge neighbour features, group them into per-node
        # sequences, and keep the LSTM's final hidden state per node.
        x_j = x[0][edge_index_j]
        out, mask = to_dense_batch(x_j, edge_index_i)
        _, (h, _) = self.lstm(out)
        return h.squeeze(0)

rusty1s commented 2 years ago

Super, glad that it is running now. Feel free to contribute your solution :)

hunarbatra commented 2 years ago

Thank you for your guidance @rusty1s :). I've opened a PR with the addition of LSTM, Max Pool and GCN aggregation for SAGEConv - #4379
Please let me know if you feel any changes are required.

Extensions

if self.aggregator_type == 'bilstm':
    x = torch.stack(x, dim=1)  # [num_nodes, num_layers, num_channels]
    alpha, _ = self.bilstm(x)
    alpha = self.att(alpha).squeeze(-1)  # [num_nodes, num_layers]
    alpha = torch.softmax(alpha, dim=-1)
    return (x * alpha.unsqueeze(-1)).sum(dim=1)

rusty1s commented 2 years ago

Thank you for the PR. Left some comments. Let's first try to integrate LSTM-style aggregation. As Bi-LSTM aggregation is not proposed in GraphSAGE, I am not yet sure whether SAGEConv is the best place to integrate this.

rusty1s commented 2 years ago

Integrated via https://github.com/pyg-team/pytorch_geometric/pull/4379

francyya commented 1 year ago

Not yet, and it has very low priority since LSTM aggregation isn't that useful for most GNNs. Nonetheless, you can easily implement this by yourself by overriding the aggregate function in MessagePassing. Given that your edge_indices are sorted column-wise, you can do:

def aggregate(self, inputs, index):
    out, mask = to_dense_batch(inputs, index)  # [num_nodes, num_neighbors, num_features]
    return self.lstm(out)

If the edges need to be sorted for SAGEConv, then an undirected heterogeneous graph might not be able to use the LSTM aggregator, since a pair of nodes that is sorted in one direction is not necessarily sorted in the other direction. Does LSTM only work for directed graphs?

rusty1s commented 1 year ago

In PyG, every graph is directed, and undirected graphs are modeled by adding bidirectional edges to your graph. As such, you can always ensure that edges are sorted in both directions.
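
A sketch of that preprocessing (the edge_index values here are made up for illustration):

import torch
from torch_geometric.utils import to_undirected

edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])
edge_index = to_undirected(edge_index)  # add the reverse of every edge
perm = edge_index[1].argsort()          # order edges column-wise (by target)
edge_index = edge_index[:, perm]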

nhewadehigah commented 1 year ago

Do we need to sort the edge_index column-wise or row-wise for the LSTM aggregator?

rusty1s commented 1 year ago

By column. In master, there is also now an option to do this more easily:
data = data.sort(sort_by_row=False)
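
For instance (a sketch, assuming a PyG build that already ships Data.sort):

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 0]])
data = Data(edge_index=edge_index, num_nodes=3)
data = data.sort(sort_by_row=False)  # sort edge_index column-wise (by target)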