Thanks for all the great work on the package, and sorry if this has been asked before. I was wondering if it is possible to use different aggregation functions, like max-pool or LSTM aggregation from the original GraphSAGE paper, with the existing SAGEConv layer? Looking at the documentation, it seems that only mean aggregation is available.
Max-Pool is indeed possible, but we do not have an implementation for it. Feel free to add it if you like. A fast LSTM aggregation is currently not possible, but I am working on it.
Hi, thanks so much for the reply. I would love to add it if I am able. Would it be possible to implement it as it is in the paper (with a second parameter matrix W_pool) using the current SAGEConv layer? Or would I need to implement it from scratch using MessagePassing? Would you have any tips on where to start?
We should maybe first come up with a separate implementation and then decide how to merge the different variants. I suggest copying most of the SAGEConv layer as a starting point. And yes, you need to introduce W_pool for the max aggregation.
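For anyone following along, a rough sketch of what that could look like as a standalone MessagePassing layer (the class and attribute names here are made up for illustration; under those assumptions it implements the paper's max({ relu(W_pool h_u + b) }) aggregator):

```python
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import MessagePassing


class MaxPoolSAGE(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max')  # feature-wise max over incoming messages
        self.lin_pool = Linear(in_channels, in_channels)  # W_pool from the paper
        self.lin_l = Linear(in_channels, out_channels)    # transform of aggregated neighbours
        self.lin_r = Linear(in_channels, out_channels)    # transform of the root node

    def forward(self, x, edge_index):
        out = self.propagate(edge_index, x=x)
        return self.lin_l(out) + self.lin_r(x)

    def message(self, x_j):
        # transform each neighbour *before* the max, as in the paper
        return F.relu(self.lin_pool(x_j))
```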
Thanks - I'll hopefully get time to look at this over the next week.
Hey guys, is there an update on the two aggregator functions? Would be keen to know.
Hey - I am still planning to have a go at this as soon as possible; I have been inundated with work since the lockdown. If you have any ideas, or would like to help, I would be most grateful.
It looks like you may just have to add an additional function here, maybe called scatter_maxpool or something? torch-scatter does already support max reduction :)
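For reference, a minimal sketch of feature-wise max aggregation with torch-scatter (the shapes are illustrative):

```python
import torch
from torch_scatter import scatter

inputs = torch.randn(5, 4)             # one message per edge, 4 features each
index = torch.tensor([0, 0, 1, 2, 2])  # target node of each message

# feature-wise max over each target node's incoming messages
out = scatter(inputs, index, dim=0, reduce='max')  # [3, 4]
```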
Hi, is there any update on the max pool aggregator? I would love to know. Also, could we not just change the aggregator by passing an extra argument to the SAGEConv constructor in torch_geometric/nn/conv/sage_conv.py? Is there anything else that needs to be done? The aggregator value there is hard-coded as 'add'. Thanks
Update 1: Okay, somehow my Jupyter notebook still shows the previous implementation of sage_conv.py. I checked the new code and see that the custom aggregator has been added. How are we supposed to pass it as an argument? Just like aggr='max', right? It errors with: TypeError: __init__() got multiple values for keyword argument 'aggr' (line 36 of sage_conv.py).
Update 2: Okay, I think I need to reinstall torch_geometric in Python.
Update 3: Okay, I am unable to install the latest code! What should I do? I have uninstalled and reinstalled all binaries and the pytorch_geometric repo (with the --no-cache-dir flag). It says PyG 1.6.1 is installed, but I still can't see this update.
Help! :/
Thanks very much.
You can install from source:

```
pip install git+https://github.com/rusty1s/pytorch_geometric.git
```

or change the aggr flag after initialization:

```python
conv = SAGEConv(...)
conv.aggr = 'max'
```
@rusty1s The max pooling suggested by the paper is the following:

aggregate_k^pool = max({ σ(W_pool · h_{u_i}^k + b), ∀ u_i ∈ N(v) })

> conv = SAGEConv(...)
> conv.aggr = 'max'

However, this way you mentioned seems like reducing the dimension by max before the affine operation in the formulation above. If I would like to do it the way described in the paper, should I set conv.aggr = None and use a feature-wise max operator?
You can simply apply self.lin_l before self.propagate.
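In a copy of SAGEConv's forward, that would look roughly like this (a sketch against the SAGEConv of that era; since lin_l maps in_channels to out_channels, the max then runs over already-transformed features):

```python
# inside a copy of SAGEConv:
def forward(self, x, edge_index, size=None):
    if isinstance(x, Tensor):
        x = (x, x)
    # transform neighbour features first, then max-aggregate
    out = self.propagate(edge_index, x=(self.lin_l(x[0]), x[1]), size=size)
    x_r = x[1]
    if self.root_weight and x_r is not None:
        out = out + self.lin_r(x_r)
    return out
```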
Hi, I am wondering if there are any updates regarding the LSTM aggregator for SAGEConv. Thanks a lot.
Not yet, and it has very low priority since LSTM aggregation isn't that useful for most GNNs. Nonetheless, you can easily implement this yourself by overriding the aggregate function in MessagePassing. Given that your edge_indices are sorted column-wise, you can do:

```python
def aggregate(self, inputs, index):
    out, mask = to_dense_batch(inputs, index)  # [num_nodes, num_neighbors, num_features]
    return self.lstm(out)
```
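One caveat with this sketch: torch.nn.LSTM returns a tuple (output, (h_n, c_n)), so to get a single vector per node you would pick something out of it, for example the final hidden state (assuming a single-layer, unidirectional LSTM):

```python
def aggregate(self, inputs, index):
    out, mask = to_dense_batch(inputs, index)  # [num_nodes, max_neighbors, num_features]
    _, (h_n, _) = self.lstm(out)               # h_n: [1, num_nodes, hidden_size]
    return h_n.squeeze(0)                      # [num_nodes, hidden_size]
```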
For max pooling - would it be right to use:

```python
out = matmul(adj_t, x[0], reduce='max')
```

(along with applying self.lin_l before self.propagate, to match the equation in the paper)?
@rusty1s This is my code below for SAGEConv. Also, LSTM isn't working for me. Could you please tell me how to implement it, or how to fix it in my implementation? Thank you so much! :)
```python
class SAGETest2(MessagePassing):
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, aggregator_type: str, normalize: bool = False,
                 root_weight: bool = True, bias: bool = True):
        # kwargs.setdefault('aggr', 'lstm')
        super(SAGETest2, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight
        self.aggregator_type = aggregator_type
        self.bias = bias

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        if self.aggregator_type == 'gcn':
            # GCN does not require self root node, but just the neighbours
            self.root_weight = False

        if aggregator_type == 'max':
            self.lin_pool = Linear(in_channels[0], in_channels[0])

        if self.aggregator_type == 'lstm':
            self.lstm = LSTM(
                in_channels[0], (2 * in_channels[0]) // 2,
                bidirectional=True,
                batch_first=True)
            self.att = Linear(2 * ((2 * in_channels[0]) // 2), 1)

        # if self.aggregator_type == 'lstm':
        #     self.lstm = nn.LSTM(
        #         input_size=in_channels,
        #         hidden_size=out_channels,
        #         batch_first=True,
        #     )

        if self.aggregator_type != 'lstm':
            self.lin_l = Linear(in_channels[0], out_channels, bias=bias)  # neighbours
        else:
            self.lin_l = Linear(in_channels[0] + (2 * in_channels[0]), out_channels, bias=bias)  # neighbours

        if self.root_weight:  # Not created for GCN
            self.lin_r = Linear(in_channels[1], out_channels, bias=False)  # root itself

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        if self.root_weight:
            self.lin_r.reset_parameters()
        if self.aggregator_type == 'lstm':
            self.lstm.reset_parameters()
        if self.aggregator_type == 'max':
            self.lin_pool.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # Determine whether to apply linear transformation before message passing A(XW)
        # in_channels = self.in_channels
        # lin_before_mp = in_channels[0] > self.out_channels
        # if lin_before_mp:
        #     # Apply linear transformation before message passing to match the eqn in the paper
        #     out = self.lin_l(x[0])  # lin_l is for neighbours, and x[0] is the source, i.e. neighbours
        #     # propagate_type: (x: OptPairTensor)
        #     out = self.propagate(edge_index, x=x, size=size)
        # else:
        #     # propagate_type: (x: OptPairTensor)
        #     out = self.propagate(edge_index, x=x, size=size)
        #     out = self.lin_l(out)

        if self.aggregator_type == 'max':
            out = F.relu(self.lin_pool(x[0]))
            out = self.propagate(edge_index, x=x, size=size)
            out = self.lin_l(out)
        else:
            out = self.propagate(edge_index, x=x, size=size)
            out = self.lin_l(out)

        x_r = x[1]  # x[1] -- root
        if self.root_weight and x_r is not None:
            out += self.lin_r(x_r)  # root doesn't get added for gcn

        # bias term
        # if self.bias is not None:
        #     out = out + self.bias

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    # def aggregate(self, inputs, index):
    #     if self.aggregator_type == 'lstm':
    #         out, mask = to_dense_batch(inputs[0], index)  # [num_nodes, num_neighbors, num_features]
    #         return self.lstm(out)

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor) -> Tensor:
        adj_t = adj_t.set_value(None, layout=None)
        if self.aggregator_type == 'mean' or self.aggregator_type == 'gcn':
            out = matmul(adj_t, x[0], reduce='mean')
        elif self.aggregator_type == 'max':
            # out = matmul(adj_t, x[0], reduce='max')
            out = torch.stack(x[0], dim=-1).max(dim=-1)[0]
        elif self.aggregator_type == 'lstm':
            # x, (h_n, c_n) = self.lstm(x[0])
            # out = h_n
            x = torch.stack(x[0], dim=1)  # [num_nodes, num_layers, num_channels]
            alpha, _ = self.lstm(x)
            alpha = self.att(alpha).squeeze(-1)  # [num_nodes, num_layers]
            alpha = torch.softmax(alpha, dim=-1)
            return (x * alpha.unsqueeze(-1)).sum(dim=1)
        return out
```
max aggregation should be easily achievable by just setting aggr="max" when instantiating SAGEConv.
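For example, assuming a PyG version recent enough that SAGEConv forwards the aggr keyword:

```python
from torch_geometric.nn import SAGEConv

conv = SAGEConv(in_channels=16, out_channels=32, aggr='max')
```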
Your LSTM implementation looks a bit confusing to me, since you are missing the edge-level x_j matrix computation:

```python
def message_and_aggregate(self, adj_t, x, edge_index_j, edge_index_i):
    x_j = x[0][edge_index_j]
    x, mask = to_dense_batch(x_j, edge_index_i)
    x = self.lstm(x)
    ...
```
Thank you so much! Finally, LSTM is working for me, and this is my final SAGEConv. Do you think I should open a PR, as it might help others looking for these implementations?
```python
from typing import Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import LSTM

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptPairTensor, Size
from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import to_dense_batch


class SAGEConv(MessagePassing):
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, aggregator_type: str = 'mean',
                 normalize: bool = False, root_weight: bool = True,
                 bias: bool = True):
        super(SAGEConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight
        self.aggregator_type = aggregator_type
        self.bias = bias

        assert self.aggregator_type in ['mean', 'max', 'lstm', 'gcn']

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        if self.aggregator_type == 'gcn':
            # GCN does not require self root node, but just the neighbours
            self.root_weight = False

        if aggregator_type == 'max':
            self.lin_pool = Linear(in_channels[0], in_channels[0])

        if self.aggregator_type == 'lstm':
            self.lstm = LSTM(in_channels[0], in_channels[0], batch_first=True)

        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)  # neighbours
        if self.root_weight:  # Not created for GCN
            self.lin_r = Linear(in_channels[1], out_channels, bias=False)  # root itself

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialises learnable parameters."""
        self.lin_l.reset_parameters()
        if self.root_weight:
            self.lin_r.reset_parameters()
        if self.aggregator_type == 'lstm':
            self.lstm.reset_parameters()
        if self.aggregator_type == 'max':
            self.lin_pool.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None) -> Tensor:
        """Computes the GraphSAGE layer."""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        out = self.propagate(edge_index, x=x, size=size)
        out = self.lin_l(out)

        x_r = x[1]  # x[1] -- root
        if self.root_weight and x_r is not None:
            out += self.lin_r(x_r)  # root doesn't get added for GCN

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: OptPairTensor,
                              edge_index_j, edge_index_i) -> Tensor:
        """Performs message passing and aggregates messages from neighbours
        using the aggregator_type."""
        adj_t = adj_t.set_value(None, layout=None)
        if self.aggregator_type == 'mean' or self.aggregator_type == 'gcn':
            return matmul(adj_t, x[0], reduce='mean')
        elif self.aggregator_type == 'max':
            return matmul(adj_t, x[0], reduce='max')
            # return torch.stack(x[0], dim=-1).max(dim=-1)[0] - alternative implementation of max pool operation
        elif self.aggregator_type == 'lstm':
            x_j = x[0][edge_index_j]
            x, mask = to_dense_batch(x_j, edge_index_i)
            return self.lstm(x)
```
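For anyone wanting to try the class above, a quick usage sketch (shapes are illustrative; note that message_and_aggregate is only triggered on the SparseTensor code path, and the LSTM branch additionally assumes edges sorted by column):

```python
import torch
from torch_sparse import SparseTensor

x = torch.randn(4, 8)             # 4 nodes, 8 features each
row = torch.tensor([0, 0, 1, 2])  # target nodes
col = torch.tensor([1, 2, 3, 0])  # source (neighbour) nodes
adj_t = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

conv = SAGEConv(8, 16, aggregator_type='max')
out = conv(x, adj_t)              # [4, 16]
```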
Super, glad that it is running now. Feel free to contribute your solution :)
Thank you for your guidance @rusty1s :). I've opened a PR with the addition of LSTM, Max Pool and GCN aggregation for SAGEConv - #4379
Please let me know if you feel any changes are required.
A BiLSTM-with-attention aggregator for SAGEConv, inspired by the layer aggregation module from the paper Representation Learning on Graphs with Jumping Knowledge Networks. Do you think I should add that as well?
Implementation:

```python
if self.aggregator_type == 'bilstm':
    self.bilstm = LSTM(in_channels[0], in_channels[0] // 2,
                       bidirectional=True, batch_first=True)
    self.att = Linear(2 * in_channels[0], 1)

...

# in message_and_aggregate()
if self.aggregator_type == 'bilstm':
    x = torch.stack(x, dim=1)  # [num_nodes, num_layers, num_channels]
    alpha, _ = self.bilstm(x)
    alpha = self.att(alpha).squeeze(-1)  # [num_nodes, num_layers]
    alpha = torch.softmax(alpha, dim=-1)
    return (x * alpha.unsqueeze(-1)).sum(dim=1)
```
Thank you for the PR. Left some comments. Let's first try to integrate LSTM-style aggregation. As Bi-LSTM aggregation is not proposed in GraphSAGE, I am not yet sure whether SAGEConv is the best place to integrate this.
Integrated via https://github.com/pyg-team/pytorch_geometric/pull/4379
> Not yet, and it has very low priority since LSTM aggregation isn't that useful for most GNNs. Nonetheless, you can easily implement this yourself by overriding the aggregate function in MessagePassing. Given that your edge_indices are sorted column-wise, you can do:
>
> ```python
> def aggregate(self, inputs, index):
>     out, mask = to_dense_batch(inputs, index)  # [num_nodes, num_neighbors, num_features]
>     return self.lstm(out)
> ```
If the edges need to be sorted for SAGEConv, then an undirected heterogeneous graph might not be able to use the LSTM aggregator, since a pair of nodes whose edges are sorted in one direction is not necessarily sorted in the other direction. Does the LSTM aggregator only work for directed graphs?
In PyG, every graph is directed, and undirected graphs are modeled by adding bidirectional edges to your graph. As such, you can always ensure that edges are sorted in both directions.
Do we need to sort the edge_index column-wise or row-wise for the LSTM aggregator?
By column. In master, there is now also an option to do this more easily:

```python
data = data.sort(sort_by_row=False)
```
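On versions without Data.sort, a manual column-wise sort is only a couple of lines (a sketch):

```python
import torch

# sort edges by target node (the second row of edge_index),
# so each node's incoming edges are contiguous
perm = edge_index[1].argsort()
edge_index = edge_index[:, perm]
```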