Closed jwaladhamala closed 4 years ago
What's the suggestion?
Ironically, I'm just working on it!
FYI, I'm still working on it, but haven't yet achieved the results reported in the paper.
Any luck with unpooling operators? It would be neat to be able to use the clustering data from downsample to upsample data again. A similar approach to the Graph U-Net unpooling operator was implemented in Ranjan et al. (Section 3.2): https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/439/1285.pdf
Hi, unpooling operations of this kind are really trivial with the means of PyTorch and PyG. Simply use the cluster vector and, e.g., gather node information:
x_unpooled = x[cluster]
It is just that we do not offer an explicit operator for this kind of operation.
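To make the gather concrete, here is a minimal sketch. It assumes mean pooling as the downsampling step and a hand-written `cluster` vector; both are illustrative choices, not something a particular PyG operator is guaranteed to produce.

```python
import torch

# Four nodes with 2 features each; `cluster` maps every node to a coarse cluster.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
cluster = torch.tensor([0, 0, 1, 1])  # illustrative assignment, chosen by hand

# Downsample (assumed here: mean of the node features within each cluster).
num_clusters = int(cluster.max()) + 1
sums = torch.zeros(num_clusters, x.size(1)).index_add_(0, cluster, x)
counts = torch.bincount(cluster, minlength=num_clusters).unsqueeze(1)
x_pooled = sums / counts

# Unpool: every node receives the features of the cluster it belongs to.
x_unpooled = x_pooled[cluster]
```

The unpooled tensor has one row per original node again, so it can be combined with features saved before pooling.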
Hi @rusty1s any luck replicating the results? Do you have a working branch we can look at? Thanks for the amazing library btw
Sadly no, it seems like we need to wait for an official implementation. We could open a WIP pull request though.
That would be nice, I'll try replicating the results in mid-August if no one has done it by then :) Any WIP code would be useful
BTW I see that in the top-k pooling you don't augment the connectivity with the graph power. This seems to be very important in their results. Might be worth adding a parameter is_augment_connectivity?
For future reference: https://github.com/HongyangGao/gunet (I had not realized there was an official implementation)
Awesome, thank you!
Another question: in #355 you seem to suggest using two hop, but that only augments the adjacency matrix without changing the weights, right? In that case it's different from the U-Net paper, where they use A^2, which I think makes more sense because you keep track of which nodes were actually close. I'm asking because maybe they even use (A+I)^2, which would make the most sense. Is there any way of implementing this in PyTorch Geometric?
Yes, the official implementation seems to not compute the graph augmentation before the graph coarsening procedure. In addition, examples on the citation graphs are missing.
Your point seems to be valid though. The TwoHop transform does indeed only add two-hop neighbors to the graph and ignores any edge weights. We can easily implement this using add_self_loops and torch-sparse.
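To illustrate the difference being discussed, here is a small dense sketch on a 4-node path graph. It uses plain dense tensors rather than add_self_loops and torch-sparse, and the "binary" variant is only an approximation of what a connectivity-only transform would give, so treat it as an illustration under those assumptions.

```python
import torch

# Path graph 0 - 1 - 2 - 3 as a dense adjacency matrix.
A = torch.tensor([[0., 1., 0., 0.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [0., 0., 1., 0.]])
I = torch.eye(4)

# Weighted augmentation: (A + I)^2 = A^2 + 2A + I.
aug = (A + I) @ (A + I)

# Connectivity-only view: the same sparsity pattern, but every weight is 1.
binary = (aug > 0).float()

print(aug[0, 1].item(), aug[0, 2].item())        # direct vs. two-hop neighbor
print(binary[0, 1].item(), binary[0, 2].item())  # indistinguishable here
```

In the weighted version a direct neighbor of node 0 ends up with weight 2 and a two-hop neighbor with weight 1, whereas the binary version cannot tell them apart.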
So here is what I have for now. It is learning, but I still need to benchmark it to see whether it is as good as they say:
def forward(self, data):
    x, edge_index, batch = data.x, data.edge_index, data.batch
    edge_weight = x.new_ones((edge_index.size(1), ))

    n_blocks = self.n_layers // 2 if self.is_double_conv else self.n_layers
    n_down_blocks = n_blocks // 2
    residuals = [None] * n_down_blocks
    edges = [None] * n_down_blocks
    perms = [None] * n_down_blocks

    # Down
    for i in range(n_down_blocks):
        x = self._apply_conv_block_i(x, edge_index, i, edge_weight=edge_weight)
        residuals[i] = x
        # not clear whether to save before or after augment
        edges[i] = (edge_index, edge_weight)
        # (A + I)^2
        edge_index, edge_weight = self.augment_adj(edge_index, edge_weight, x.size(0))
        x, edge_index, edge_weight, batch, perm = self.pools[i](x, edge_index,
                                                                edge_attr=edge_weight,
                                                                batch=batch)
        perms[i] = perm

    # Bottleneck
    x = self._apply_conv_block_i(x, edge_index, n_down_blocks, edge_weight=edge_weight)

    # Up: n_down_blocks - i is negative, so the lists are read from the end
    for i in range(n_down_blocks + 1, n_blocks):
        edge_index, edge_weight = edges[n_down_blocks - i]
        res = residuals[n_down_blocks - i]
        up = torch.zeros_like(res)
        up[perms[n_down_blocks - i]] = x
        x = torch.cat((res, up), dim=1)  # concat on channels
        x = self._apply_conv_block_i(x, edge_index, i, edge_weight=edge_weight)

    data.x, data.edge_index, data.batch = x, edge_index, batch
    return data

def augment_adj(self, edge_index, edge_weight, n_nodes):
    # weights are passed positionally and the number of nodes explicitly
    edge_index, edge_weight = torch_geometric.utils.add_self_loops(
        edge_index, edge_weight, num_nodes=n_nodes)
    edge_index, edge_weight = torch_sparse.spspmm(edge_index, edge_weight,
                                                  edge_index, edge_weight,
                                                  n_nodes, n_nodes, n_nodes)
    return edge_index, edge_weight

def _apply_conv_block_i(self, x, edge_index, i, **kwargs):
    """
    Apply the i-th convolution block. A usual U-Net applies 2 convolutions
    before down/up sampling.
    """
    if self.is_double_conv:
        i *= 2
    x = self.activation_(self.norms[i](self.convs[i](x, edge_index, **kwargs)))
    if self.is_double_conv:
        x = self.activation_(self.norms[i + 1](self.convs[i + 1](x, edge_index, **kwargs)))
    return x
I'm very new to graph NNs and PyTorch Geometric. @rusty1s, is there anything that stands out as very wrong?
Looks quite perfect to me :)
Two questions:
1/ In the image segmentation world, Unet always has 2 convolutions per block (usually two 3x3 convolutions). The paper uses only a single one, I wanted to give the possibility of using both for benchmarks. TBH I don't have a good intuition of why it works better for images.
2/ They say they can do both but the code shows summing, indeed. Again in the image world it's standard to concatenate. I will add a parameter to deal with both cases.
From very initial experiments, it seems that concat is better (although the number of parameters is a lot larger, so the two cannot be compared directly) while double conv is not (maybe because in the graph world stacking many convolutions doesn't seem as helpful as in the image world).
Unfortunately, I won't have time to run all the benchmarks until the 15th of August. If no one has done it by then, I will.
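As a small illustration of the two skip-connection variants, here is a sketch with made-up features. The tensors and the `perm` indices are invented for the example; only the zero-filled unpooling and the sum/concat step mirror the code in this thread.

```python
import torch

# Skip-connection features saved before pooling: 4 nodes, 2 channels.
res = torch.tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.]])
# Features of the 2 nodes kept by pooling, and their indices in the finer graph.
x = torch.tensor([[10., 10.], [20., 20.]])
perm = torch.tensor([1, 3])

# Unpool: kept nodes go back to their original slots, dropped nodes stay zero.
up = torch.zeros_like(res)
up[perm] = x

summed = res + up                      # channel count unchanged (2)
concat = torch.cat((res, up), dim=1)   # channel count doubled (4)
```

With concatenation the next convolution needs twice as many input channels, which is where the extra parameters come from.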
For reference here is the full code:
import torch
import torch.nn as nn
import torch_geometric
import torch_sparse


class GraphUnet(torch.nn.Module):
    def __init__(self, n_channels,
                 Conv=torch_geometric.nn.GCNConv,
                 n_layers=5,
                 Activation=nn.ReLU,
                 Normalization=nn.Identity,
                 Pool=torch_geometric.nn.TopKPooling,
                 is_double_conv=False,
                 max_nchannels=1024,  # enables bounding as channel grows exponentially
                 is_sum_res=True,
                 factor_chan=2,
                 **kwargs):
        super().__init__()
        self.activation_ = Activation(inplace=True)
        self.n_layers = n_layers  # needed in forward()
        self.is_double_conv = is_double_conv
        self.max_nchannels = max_nchannels
        self.is_sum_res = is_sum_res
        self.factor_chan = factor_chan
        self.in_out_channels = self._get_in_out_channels(n_channels, n_layers)
        self.convs = nn.ModuleList([Conv(in_chan, out_chan, **kwargs)
                                    for in_chan, out_chan in self.in_out_channels])
        self.norms = nn.ModuleList([Normalization(out_chan)
                                    for _, out_chan in self.in_out_channels])
        pool_in_chan = [min(self.factor_chan**i * n_channels, max_nchannels)
                        for i in range(n_layers // 2 + 1)][1:]
        self.pools = nn.ModuleList([Pool(in_chan) for in_chan in pool_in_chan])
        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_weight = x.new_ones((edge_index.size(1), ))

        n_blocks = self.n_layers // 2 if self.is_double_conv else self.n_layers
        n_down_blocks = n_blocks // 2
        residuals = [None] * n_down_blocks
        edges = [None] * n_down_blocks
        perms = [None] * n_down_blocks

        # Down
        for i in range(n_down_blocks):
            x = self._apply_conv_block_i(x, edge_index, i, edge_weight=edge_weight)
            residuals[i] = x
            # not clear whether to save before or after augment
            edges[i] = (edge_index, edge_weight)
            # (A + I)^2
            edge_index, edge_weight = self.augment_adj(edge_index, edge_weight, x.size(0))
            x, edge_index, edge_weight, batch, perm = self.pools[i](x, edge_index,
                                                                    edge_attr=edge_weight,
                                                                    batch=batch)
            perms[i] = perm

        # Bottleneck
        x = self._apply_conv_block_i(x, edge_index, n_down_blocks, edge_weight=edge_weight)

        # Up: n_down_blocks - i is negative, so the lists are read from the end
        for i in range(n_down_blocks + 1, n_blocks):
            edge_index, edge_weight = edges[n_down_blocks - i]
            res = residuals[n_down_blocks - i]
            up = torch.zeros_like(res)
            up[perms[n_down_blocks - i]] = x
            if not self.is_sum_res:
                x = torch.cat((res, up), dim=1)  # concat on channels
            else:
                x = res + up
            x = self._apply_conv_block_i(x, edge_index, i, edge_weight=edge_weight)

        data.x, data.edge_index, data.batch = x, edge_index, batch
        return data

    def augment_adj(self, edge_index, edge_weight, n_nodes):
        # weights are passed positionally and the number of nodes explicitly
        edge_index, edge_weight = torch_geometric.utils.add_self_loops(
            edge_index, edge_weight, num_nodes=n_nodes)
        edge_index, edge_weight = torch_sparse.spspmm(edge_index, edge_weight,
                                                      edge_index, edge_weight,
                                                      n_nodes, n_nodes, n_nodes)
        return edge_index, edge_weight

    def _apply_conv_block_i(self, x, edge_index, i, **kwargs):
        """Apply the i-th convolution block."""
        if self.is_double_conv:
            i *= 2
        x = self.activation_(self.norms[i](self.convs[i](x, edge_index, **kwargs)))
        if self.is_double_conv:
            x = self.activation_(self.norms[i + 1](self.convs[i + 1](x, edge_index, **kwargs)))
        return x

    def _get_in_out_channels(self, n_channels, n_layers):
        """Return a list of tuples of input and output channels for a U-Net."""
        if self.is_double_conv:
            assert n_layers % 2 == 0, "n_layers={} not even".format(n_layers)
            # e.g. if n_channels=16, n_layers=10: [16, 32, 64]
            channel_list = [self.factor_chan**i * n_channels for i in range(n_layers // 4 + 1)]
            # e.g.: [16, 16, 32, 32, 64, 64]
            channel_list = [i for i in channel_list for _ in (0, 1)]
            # e.g.: [16, 16, 32, 32, 64, 64, 64, 32, 32, 16, 16]
            channel_list = channel_list + channel_list[-2::-1]
            # bound max number of channels by self.max_nchannels
            channel_list = [min(c, self.max_nchannels) for c in channel_list]
            # e.g.: [..., (32, 32), (32, 64), (64, 64), (64, 32), (32, 32), (32, 16) ...]
            in_out_channels = self._list_chan_to_tuple(channel_list, n_layers)
            if not self.is_sum_res:
                # e.g.: [..., (32, 32), (32, 64), (64, 64), (128, 32), (32, 32), (64, 16) ...]
                # due to concat
                idcs = slice(len(in_out_channels) // 2 + 1, len(in_out_channels), 2)
                in_out_channels[idcs] = [(in_chan * 2, out_chan)
                                         for in_chan, out_chan in in_out_channels[idcs]]
        else:
            assert n_layers % 2 == 1, "n_layers={} not odd".format(n_layers)
            # e.g. if n_channels=16, n_layers=5: [16, 32, 64]
            channel_list = [self.factor_chan**i * n_channels for i in range(n_layers // 2 + 1)]
            # e.g.: [16, 32, 64, 64, 32, 16]
            channel_list = channel_list + channel_list[::-1]
            # bound max number of channels by self.max_nchannels
            channel_list = [min(c, self.max_nchannels) for c in channel_list]
            # e.g.: [(16, 32), (32, 64), (64, 64), (64, 32), (32, 16)]
            in_out_channels = self._list_chan_to_tuple(channel_list, n_layers)
            if not self.is_sum_res:
                # e.g.: [(16, 32), (32, 64), (64, 64), (128, 32), (64, 16)] due to concat
                idcs = slice(len(in_out_channels) // 2 + 1, len(in_out_channels))
                in_out_channels[idcs] = [(in_chan * 2, out_chan)
                                         for in_chan, out_chan in in_out_channels[idcs]]
        return in_out_channels

    def _list_chan_to_tuple(self, n_channels, n_layers):
        """Return a list of tuples of input and output channels."""
        channel_list = list(n_channels)
        assert len(channel_list) == n_layers + 1
        return list(zip(channel_list, channel_list[1:]))
Sorry, _get_in_out_channels is very dirty code, but I wanted it to be very general (i.e. you can choose the number of layers, channels...) for benchmarks (I'm sure there are cleaner ways, I haven't thought about it for too long).
Note that the model is independent of the convolution, so you should be able to use GAT, GIN, GCN ...
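For what it's worth, one possibly cleaner way to get the channel pairs for the single-convolution case could look like the sketch below. The function name `unet_channels` and its parameters are hypothetical and not part of the class above; it only covers the odd `n_layers` branch and is meant as an illustration, not a drop-in replacement.

```python
def unet_channels(n_channels, n_layers, factor=2, max_channels=1024, concat=False):
    """Return (in, out) channel pairs for a U-Net with one convolution per block."""
    assert n_layers % 2 == 1, "n_layers must be odd"
    depth = n_layers // 2
    # Channel widths on the way down, bounded by max_channels.
    down = [min(factor**i * n_channels, max_channels) for i in range(depth + 1)]
    # Mirror the widths for the way up and pair consecutive entries.
    widths = down + down[::-1]
    pairs = list(zip(widths, widths[1:]))
    if concat:
        # Blocks after the bottleneck see the skip connection concatenated to
        # their input, so their input width doubles.
        pairs = [(2 * c_in, c_out) if i > depth else (c_in, c_out)
                 for i, (c_in, c_out) in enumerate(pairs)]
    return pairs


print(unet_channels(16, 5))
print(unet_channels(16, 5, concat=True))
```

The two calls correspond to the example lists given in the comments of the single-convolution branch of _get_in_out_channels.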
The paper model should be something like
from functools import partial
from torch_geometric.nn import GCNConv
GraphUnet(32, Conv=partial(GCNConv, improved=True), n_layers=7, is_sum_res=True, is_double_conv=False)
Model:
GraphUnet(
  (activation_): ReLU(inplace)
  (convs): ModuleList(
    (0): GCNConv(32, 64)
    (1): GCNConv(64, 128)
    (2): GCNConv(128, 256)
    (3): GCNConv(256, 256)
    (4): GCNConv(256, 128)
    (5): GCNConv(128, 64)
    (6): GCNConv(64, 32)
  )
  (norms): ModuleList(
    (0): Identity()
    (1): Identity()
    (2): Identity()
    (3): Identity()
    (4): Identity()
    (5): Identity()
    (6): Identity()
  )
  (pools): ModuleList(
    (0): TopKPooling(64, ratio=0.5)
    (1): TopKPooling(128, ratio=0.5)
    (2): TopKPooling(256, ratio=0.5)
  )
)
Thank you. It would be great if you could submit a PR so that you get credit for it (it can be dirty, I can clean it up).
OK, I'll send the code above as a PR. Do you want it in nn? (It's a meta layer / network rather than a layer, so I'm not sure where it should go.)
You can put it in nn.models
Sorry for the incomplete suggestion. Is it possible to include gUnpool (the graph un-pooling layer) described in the Graph U-Net paper (https://openreview.net/pdf?id=HJePRoAct7)?