pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Graph U-Net Implementation #37

Closed: jwaladhamala closed this issue 4 years ago

jwaladhamala commented 5 years ago

Sorry for the incomplete suggestion. Would it be possible to include gUnpool (the graph unpooling layer) described in the Graph U-Net paper (https://openreview.net/pdf?id=HJePRoAct7)?

rusty1s commented 5 years ago

What's the suggestion?

rusty1s commented 5 years ago

Ironically, I'm just working on it!

rusty1s commented 5 years ago

FYI, I'm still working on it, but haven't yet achieved the results reported in the paper.

Caykroyd commented 5 years ago

Any luck with unpooling operators? It would be neat to be able to reuse the clustering data from downsampling to upsample the data again. An unpooling operator similar to the Graph U-Net one was implemented in Ranjan et al. (Section 3.2): https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/439/1285.pdf

rusty1s commented 5 years ago

Hi, unpooling operations of this kind are straightforward to write with plain PyTorch and PyG. Simply use the cluster vector to, e.g., gather node information:

x_unpooled = x[cluster]

It is just that we do not offer an explicit operator for these kinds of operations.
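Both flavours can be sketched in plain PyTorch on toy tensors: gathering through a cluster vector as above (assuming `cluster[i]` is the coarse node that fine node `i` was merged into), and gUnpool-style scattering, where the pooled features are written back into a zero matrix at the node indices `perm` kept by top-k pooling. The helper names are hypothetical, not a PyG API.

```python
import torch


def unpool_by_cluster(x_coarse: torch.Tensor, cluster: torch.Tensor) -> torch.Tensor:
    # Every fine node copies the features of the coarse node it was assigned to.
    return x_coarse[cluster]


def unpool_by_perm(x_pooled: torch.Tensor, perm: torch.Tensor, num_nodes: int) -> torch.Tensor:
    # gUnpool-style: nodes dropped by top-k pooling get zero features,
    # kept nodes get their pooled features back at their original positions.
    out = x_pooled.new_zeros((num_nodes, x_pooled.size(1)))
    out[perm] = x_pooled
    return out


# 5 fine nodes merged into 2 clusters.
cluster = torch.tensor([0, 0, 1, 1, 1])
x_coarse = torch.tensor([[1.0], [2.0]])
print(unpool_by_cluster(x_coarse, cluster).squeeze(1).tolist())  # [1.0, 1.0, 2.0, 2.0, 2.0]

# Top-k pooling kept nodes 3 and 0 (in that order) out of 5.
perm = torch.tensor([3, 0])
x_pooled = torch.tensor([[7.0], [9.0]])
print(unpool_by_perm(x_pooled, perm, 5).squeeze(1).tolist())  # [9.0, 0.0, 0.0, 7.0, 0.0]
```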

YannDubs commented 4 years ago

Hi @rusty1s, any luck replicating the results? Do you have a working branch we can look at? Thanks for the amazing library, btw.

rusty1s commented 4 years ago

Sadly no, it seems like we need to wait for an official implementation. We could open a WIP pull request though.

YannDubs commented 4 years ago

That would be nice. I'll try replicating the results in mid-August if no one has done it by then :) Any WIP code would be useful.

YannDubs commented 4 years ago

BTW, I see that the top-k pooling doesn't augment the connectivity with the graph power. This seems to be very important for their results. Might it be worth adding an is_augment_connectivity parameter?
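A small dense illustration of why the graph power matters, assuming a toy path graph 0-1-2-3-4 and a hand-picked set of kept nodes (top-k pooling would instead select them by score): without augmentation the kept nodes end up with no edges between them, while with (A + I)^2 the two-hop neighbours stay connected. The helpers `dense_adj` and `num_edges` are hypothetical, for illustration only.

```python
import torch


def dense_adj(edges, n):
    # Symmetric 0/1 adjacency matrix from an undirected edge list.
    a = torch.zeros(n, n)
    for u, v in edges:
        a[u, v] = a[v, u] = 1.0
    return a


def num_edges(adj):
    # Count undirected edges, ignoring self-loops on the diagonal.
    off_diag = adj - torch.diag(torch.diagonal(adj))
    return int((off_diag > 0).sum().item()) // 2


a = dense_adj([(0, 1), (1, 2), (2, 3), (3, 4)], 5)
keep = torch.tensor([0, 2, 4])  # nodes a pooling step might keep

# Without augmentation, the kept nodes are completely disconnected.
print(num_edges(a[keep][:, keep]))  # 0

# With (A + I)^2, nodes two hops apart remain connected after pooling.
a_aug = torch.matrix_power(a + torch.eye(5), 2)
print(num_edges(a_aug[keep][:, keep]))  # 2
```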

YannDubs commented 4 years ago

For future reference: https://github.com/HongyangGao/gunet (I had not realized there was an official implementation)

rusty1s commented 4 years ago

Awesome, thank you!

YannDubs commented 4 years ago

Another question: in #355 you seem to suggest using the two-hop transform, but that only augments the adjacency matrix without changing the weights, right? In that case it differs from the U-Net paper, where they use A^2, which I think makes more sense because it keeps track of which nodes were actually close. I ask because they may even use (A+I)^2, which would make the most sense. Is there any way of implementing this in PyTorch Geometric?
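A dense toy comparison of the two options on a path graph 0-1-2-3, assuming unit edge weights: the weighted (A + I)^2 gives direct neighbours a larger weight than two-hop neighbours, whereas a binary two-hop pattern (edge weights ignored, as described above) treats them the same. This is only an illustration of the idea, not the TwoHop transform's actual code.

```python
import torch

# Path graph 0-1-2-3 as a dense adjacency matrix with unit weights.
a = torch.zeros(4, 4)
for u, v in [(0, 1), (1, 2), (2, 3)]:
    a[u, v] = a[v, u] = 1.0
eye = torch.eye(4)

# Weighted augmentation: (A + I)^2 counts walks of length <= 2 between nodes.
weighted = (a + eye) @ (a + eye)

# Binary two-hop pattern (illustrative): the nonzero pattern of A + A^2,
# without self-loops, where every edge gets the same unit weight.
binary = ((a + a @ a) > 0).float() * (1 - eye)

print(weighted[0].tolist())  # [2.0, 2.0, 1.0, 0.0]
print(binary[0].tolist())    # [0.0, 1.0, 1.0, 0.0]
```

In row 0, node 1 (a direct neighbour) gets weight 2 and node 2 (two hops away) gets weight 1 under (A + I)^2, while the binary pattern cannot tell them apart.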

rusty1s commented 4 years ago

Yes, the official implementation does not seem to compute the graph augmentation before the graph coarsening procedure. In addition, it is missing the examples on the citation graphs.

Your point seems to be valid though. The TwoHop transform does indeed only add two-hop neighbors to the graph and ignores any edge weights. We can easily implement this using add_self_loops and torch-sparse.
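For comparison, the same (A + I)^2 augmentation can be sketched with only PyTorch's built-in sparse COO tensors instead of torch-sparse. This assumes a recent PyTorch in which `torch.sparse.mm` accepts two sparse operands, and an input graph without existing self-loops; `augment_adj_sparse` is a hypothetical name.

```python
import torch


def augment_adj_sparse(edge_index, edge_weight, num_nodes):
    # "+ I": append one self-loop per node with weight 1. Assumes edge_index has
    # no self-loops yet; otherwise coalesce() would sum them with the new ones.
    loops = torch.arange(num_nodes, device=edge_index.device)
    edge_index = torch.cat([edge_index, torch.stack([loops, loops])], dim=1)
    edge_weight = torch.cat([edge_weight, edge_weight.new_ones(num_nodes)])
    adj = torch.sparse_coo_tensor(edge_index, edge_weight,
                                  (num_nodes, num_nodes)).coalesce()
    # Sparse-sparse product (A + I) @ (A + I).
    adj2 = torch.sparse.mm(adj, adj).coalesce()
    return adj2.indices(), adj2.values()


# Undirected path 0-1-2 stored as directed edges in both directions.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
idx, val = augment_adj_sparse(edge_index, torch.ones(4), 3)
print(torch.sparse_coo_tensor(idx, val, (3, 3)).to_dense())
```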

YannDubs commented 4 years ago

So here is what I have for now. It is learning, but I should benchmark it to see whether it is as good as they report:

def forward(self, data):
    x, edge_index, batch = data.x, data.edge_index, data.batch
    edge_weight = x.new_ones((edge_index.size(1), ))

    n_blocks = self.n_layers // 2 if self.is_double_conv else self.n_layers
    n_down_blocks = n_blocks // 2
    residuals = [None] * n_down_blocks
    edges = [None] * n_down_blocks
    perms = [None] * n_down_blocks

    # Down
    for i in range(n_down_blocks):
        x = self._apply_conv_block_i(x, edge_index, i, edge_weight=edge_weight)
        residuals[i] = x

        # not clear whether to save before or after augment
        edges[i] = (edge_index, edge_weight)

        # (A + I)^2
        edge_index, edge_weight = self.augment_adj(edge_index, edge_weight, x.size(0))

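        # Older PyG versions return 5 values; newer ones also return the
        # scores of the kept nodes, so keep only the first five.
        x, edge_index, edge_weight, batch, perm = self.pools[i](
            x, edge_index, edge_attr=edge_weight, batch=batch)[:5]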
        perms[i] = perm

    # Bottleneck
    x = self._apply_conv_block_i(x, edge_index, n_down_blocks, edge_weight=edge_weight)

    # Up
    for i in range(n_down_blocks + 1, n_blocks):
        edge_index, edge_weight = edges[n_down_blocks - i]
        res = residuals[n_down_blocks - i]
        up = torch.zeros_like(res)
        up[perms[n_down_blocks - i]] = x
        x = torch.cat((res, up), dim=1)  # concatenate along channels
        x = self._apply_conv_block_i(x, edge_index, i, edge_weight=edge_weight)

    # data.batch is left unchanged: the pooled batch vector no longer matches x
    data.x, data.edge_index = x, edge_index

    return data

def augment_adj(self, edge_index, edge_weight, n_nodes):
    # A + I: pass num_nodes so that isolated nodes also get a self-loop
    edge_index, edge_weight = torch_geometric.utils.add_self_loops(
        edge_index, edge_weight, num_nodes=n_nodes)
    # (A + I)^2 as a sparse-sparse matrix product
    edge_index, edge_weight = torch_sparse.spspmm(edge_index, edge_weight,
                                                  edge_index, edge_weight,
                                                  n_nodes, n_nodes, n_nodes)
    return edge_index, edge_weight

def _apply_conv_block_i(self, x, edge_index, i, **kwargs):
    """
    Apply the i^th convolution block. A standard U-Net applies two convolutions
    before each down/up-sampling step.
    """
    if self.is_double_conv:
        i *= 2

    x = self.activation_(self.norms[i](self.convs[i](x, edge_index, **kwargs)))

    if self.is_double_conv:
        x = self.activation_(self.norms[i + 1](self.convs[i + 1](x, edge_index, **kwargs)))

    return x

I'm very new to graph NNs and PyTorch Geometric. @rusty1s, does anything stand out as very wrong?
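The up path relies on Python's negative indexing: `n_down_blocks - i` is negative for every up block, so the saved residuals, edges and perms are read back deepest-first. A small pure-Python check of the loop bounds used in `forward()`, assuming an odd number of blocks so that the down and up paths mirror each other; `unet_block_schedule` is a hypothetical helper, not part of the posted code.

```python
def unet_block_schedule(n_blocks):
    """Reproduce the block indices visited by forward() (hypothetical helper)."""
    n_down_blocks = n_blocks // 2
    down = list(range(n_down_blocks))
    bottleneck = n_down_blocks
    # For each up block: (conv block i, index used in forward(), equivalent positive index)
    up = [(i, n_down_blocks - i, 2 * n_down_blocks - i)
          for i in range(n_down_blocks + 1, n_blocks)]
    return down, bottleneck, up


down, bottleneck, up = unet_block_schedule(7)
print(down)        # [0, 1, 2]
print(bottleneck)  # 3
print(up)          # [(4, -1, 2), (5, -2, 1), (6, -3, 0)]
```

With an even `n_blocks` the up path would have one block fewer than the down path, so `residuals[0]` would never be used and the output would stay at a pooled resolution.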

rusty1s commented 4 years ago

Looks quite perfect to me :)

rusty1s commented 4 years ago

Two questions:

  1. What is the intuition behind the double convs? Is this really needed?
  2. I think the paper states that skip connections are summed instead of concatenated.

YannDubs commented 4 years ago

1/ In the image segmentation world, U-Net typically has two convolutions per block (usually two 3x3 convolutions). The paper uses only one; I wanted to offer both options for benchmarking. TBH, I don't have a good intuition for why it works better for images.

2/ They say either works, but the code does indeed sum. Again, in the image world it's standard to concatenate. I will add a parameter to handle both cases.

YannDubs commented 4 years ago

From very early experiments, concatenation seems better (although it has many more parameters, so the two cannot be compared directly), while double conv does not (maybe because in the graph world stacking many convolutions doesn't seem as helpful as in the image world).
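To make the parameter comparison concrete, here is a small sketch with random features (all sizes are arbitrary assumptions): concatenating the skip connection doubles the input width of the next layer, while summing keeps it unchanged. An `nn.Linear` stands in for the following graph convolution; that substitution is an assumption, chosen because a GCN layer's weight matrix also grows linearly with its input channels.

```python
import torch
import torch.nn as nn

num_nodes, channels, out_channels = 5, 64, 32
res = torch.randn(num_nodes, channels)  # skip connection from the down path
up = torch.randn(num_nodes, channels)   # unpooled features from the level below

summed = res + up                       # shape [5, 64]
concat = torch.cat((res, up), dim=1)    # shape [5, 128]

# Stand-ins for the next convolution after each variant.
lin_sum = nn.Linear(channels, out_channels)
lin_cat = nn.Linear(2 * channels, out_channels)


def n_params(module):
    return sum(p.numel() for p in module.parameters())


print(summed.shape, concat.shape)            # torch.Size([5, 64]) torch.Size([5, 128])
print(n_params(lin_sum), n_params(lin_cat))  # 2080 4128
```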

YannDubs commented 4 years ago

Unfortunately, I won't have time to run all the benchmarks until August 15th. If no one has done it by then, I will.

For reference here is the full code:

import torch
import torch.nn as nn
import torch_sparse
import torch_geometric
import torch_geometric.nn
import torch_geometric.utils

class GraphUnet(torch.nn.Module):
    def __init__(self, n_channels,
                 Conv=torch_geometric.nn.GCNConv,
                 n_layers=5,
                 Activation=nn.ReLU,
                 Normalization=nn.Identity,
                 Pool=torch_geometric.nn.TopKPooling,
                 is_double_conv=False,
                 max_nchannels=1024, # enables bounding as channel grows exponentially
                 is_sum_res=True,
                 factor_chan=2,
                 **kwargs):
        super().__init__()

        self.activation_ = Activation(inplace=True)
        self.is_double_conv = is_double_conv
        self.max_nchannels = max_nchannels
        self.is_sum_res = is_sum_res
        self.factor_chan = factor_chan
        self.n_layers = n_layers  # needed by forward() to count the blocks

        self.in_out_channels = self._get_in_out_channels(n_channels, n_layers)
        self.convs = nn.ModuleList([Conv(in_chan, out_chan, **kwargs)
                                    for in_chan, out_chan in self.in_out_channels])

        self.norms = nn.ModuleList([Normalization(out_chan)
                                    for _, out_chan in self.in_out_channels])

        pool_in_chan = [min(self.factor_chan**i * n_channels, max_nchannels)
                        for i in range(n_layers // 2 + 1)][1:]
        self.pools = nn.ModuleList([Pool(in_chan) for in_chan in pool_in_chan])

        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, data):

        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_weight = x.new_ones((edge_index.size(1), ))

        n_blocks = self.n_layers // 2 if self.is_double_conv else self.n_layers
        n_down_blocks = n_blocks // 2
        residuals = [None] * n_down_blocks
        edges = [None] * n_down_blocks
        perms = [None] * n_down_blocks

        # Down
        for i in range(n_down_blocks):
            x = self._apply_conv_block_i(x, edge_index, i, edge_weight=edge_weight)
            residuals[i] = x

            # not clear whether to save before or after augment
            edges[i] = (edge_index, edge_weight)

            # (A + I)^2
            edge_index, edge_weight = self.augment_adj(edge_index, edge_weight, x.size(0))

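            # Older PyG versions return 5 values; newer ones also return the
            # scores of the kept nodes, so keep only the first five.
            x, edge_index, edge_weight, batch, perm = self.pools[i](
                x, edge_index, edge_attr=edge_weight, batch=batch)[:5]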
            perms[i] = perm

        # Bottleneck
        x = self._apply_conv_block_i(x, edge_index, n_down_blocks, edge_weight=edge_weight)

        # Up
        for i in range(n_down_blocks + 1, n_blocks):
            edge_index, edge_weight = edges[n_down_blocks - i]
            res = residuals[n_down_blocks - i]
            up = torch.zeros_like(res)
            up[perms[n_down_blocks - i]] = x
            if not self.is_sum_res:
                x = torch.cat((res, up), dim=1)  # concatenate along channels
            else:
                x = res + up
            x = self._apply_conv_block_i(x, edge_index, i, edge_weight=edge_weight)

        # data.batch is left unchanged: the pooled batch vector no longer matches x
        data.x, data.edge_index = x, edge_index

        return data

    def augment_adj(self, edge_index, edge_weight, n_nodes):
        # A + I: pass num_nodes so that isolated nodes also get a self-loop
        edge_index, edge_weight = torch_geometric.utils.add_self_loops(
            edge_index, edge_weight, num_nodes=n_nodes)
        # (A + I)^2 as a sparse-sparse matrix product
        edge_index, edge_weight = torch_sparse.spspmm(edge_index, edge_weight,
                                                      edge_index, edge_weight,
                                                      n_nodes, n_nodes, n_nodes)
        return edge_index, edge_weight

    def _apply_conv_block_i(self, x, edge_index, i, **kwargs):
        """Apply the i^th convolution block."""
        if self.is_double_conv:
            i *= 2

        x = self.activation_(self.norms[i](self.convs[i](x, edge_index, **kwargs)))

        if self.is_double_conv:
            x = self.activation_(self.norms[i + 1](self.convs[i + 1](x, edge_index, **kwargs)))

        return x

    def _get_in_out_channels(self, n_channels, n_layers):
        """Return a list of tuple of input and output channels for a Unet."""

        if self.is_double_conv:
            # two convs per block and an odd number of blocks, e.g. 2, 6, 10, ...
            assert n_layers % 4 == 2, "n_layers={} must be twice an odd number".format(n_layers)
            # e.g. if n_channels=16, n_layers=10: [16, 32, 64]
            channel_list = [self.factor_chan**i * n_channels for i in range(n_layers // 4 + 1)]
            # e.g.: [16, 16, 32, 32, 64, 64]
            channel_list = [i for i in channel_list for _ in (0, 1)]
            # e.g.: [16, 16, 32, 32, 64, 64, 64, 32, 32, 16, 16]
            channel_list = channel_list + channel_list[-2::-1]
            # bound max number of channels by self.max_nchannels
            channel_list = [min(c, self.max_nchannels) for c in channel_list]
            # e.g.: [..., (32, 32), (32, 64), (64, 64), (64, 32), (32, 32), (32, 16) ...]
            in_out_channels = self._list_chan_to_tuple(channel_list, n_layers)
            if not self.is_sum_res:
                # e.g.: [..., (32, 32), (32, 64), (64, 64), (128, 32), (32, 32), (64, 16) ...]
                # due to concat
                idcs = slice(len(in_out_channels) // 2 + 1, len(in_out_channels), 2)
                in_out_channels[idcs] = [(in_chan * 2, out_chan)
                                         for in_chan, out_chan in in_out_channels[idcs]]
        else:
            assert n_layers % 2 == 1, "n_layers={} not odd".format(n_layers)
            # e.g. if n_channels=16, n_layers=5: [16, 32, 64]
            channel_list = [self.factor_chan**i * n_channels for i in range(n_layers // 2 + 1)]
            # e.g.: [16, 32, 64, 64, 32, 16]
            channel_list = channel_list + channel_list[::-1]
            # bound max number of channels by self.max_nchannels
            channel_list = [min(c, self.max_nchannels) for c in channel_list]
            # e.g.: [(16, 32), (32,64), (64, 64), (64, 32), (32, 16)]
            in_out_channels = self._list_chan_to_tuple(channel_list, n_layers)
            if not self.is_sum_res:
                # e.g.: [(16, 32), (32,64), (64, 64), (128, 32), (64, 16)] due to concat
                idcs = slice(len(in_out_channels) // 2 + 1, len(in_out_channels))
                in_out_channels[idcs] = [(in_chan * 2, out_chan)
                                         for in_chan, out_chan in in_out_channels[idcs]]

        return in_out_channels

    def _list_chan_to_tuple(self, n_channels, n_layers):
        """Return a list of tuple of input and output channels."""
        channel_list = list(n_channels)
        assert len(channel_list) == n_layers + 1
        return list(zip(channel_list, channel_list[1:]))

Sorry, _get_in_out_channels is very dirty code, but I wanted it to be general (e.g. to choose the number of layers, channels, ...) for benchmarks. I'm sure there are cleaner ways; I haven't thought about it for long.
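One possibly cleaner way to write the single-conv case (is_double_conv=False), as a standalone sketch: `unet_in_out_channels` is a hypothetical helper that builds the same mirrored channel list and, when skip connections are concatenated, doubles the input width of every layer after the bottleneck. It reproduces the schedules in the code comments above and the GCNConv sizes in the printed model below for n_channels=32, n_layers=7.

```python
def unet_in_out_channels(n_channels, n_layers, factor=2, max_channels=1024, sum_res=True):
    """Single-conv-per-block channel schedule (hypothetical cleaner variant)."""
    assert n_layers % 2 == 1, "n_layers={} must be odd".format(n_layers)
    depth = n_layers // 2
    # Down path widths, bounded by max_channels, then mirrored for the up path.
    down = [min(factor ** i * n_channels, max_channels) for i in range(depth + 1)]
    chans = down + down[::-1]
    pairs = list(zip(chans, chans[1:]))
    if not sum_res:
        # Layers after the bottleneck see [skip, unpooled] concatenated.
        pairs = pairs[:depth + 1] + [(2 * c_in, c_out) for c_in, c_out in pairs[depth + 1:]]
    return pairs


print(unet_in_out_channels(16, 5))
print(unet_in_out_channels(16, 5, sum_res=False))
```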

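Note that the model is independent of the convolution, so you should be able to use GAT, GIN, GCN, ... as long as the layer's forward() accepts the edge_weight keyword that the blocks pass in (GCNConv does; GATConv and GINConv, for example, do not).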
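A minimal adapter sketch for layers whose forward() has no edge_weight argument, assuming the augmented weights may simply be ignored for such layers (a modelling trade-off, since the (A + I)^2 weights are then lost). `IgnoreEdgeWeight` and `ToyConv` are hypothetical names; with PyG one might pass something like `Conv=lambda i, o, **kw: IgnoreEdgeWeight(GATConv(i, o, **kw))` to GraphUnet.

```python
import torch
import torch.nn as nn


class IgnoreEdgeWeight(nn.Module):
    """Hypothetical adapter: lets a conv without an edge_weight argument be
    called as conv(x, edge_index, edge_weight=...)."""

    def __init__(self, conv: nn.Module):
        super().__init__()
        self.conv = conv

    def reset_parameters(self):
        # GraphUnet.reset_parameters() calls this on every conv.
        self.conv.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        # The edge weights are dropped here on purpose.
        return self.conv(x, edge_index)


class ToyConv(nn.Module):
    """Toy stand-in for a conv whose forward() takes only (x, edge_index)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def reset_parameters(self):
        self.lin.reset_parameters()

    def forward(self, x, edge_index):
        return self.lin(x)


conv = IgnoreEdgeWeight(ToyConv(2, 3))
edge_index = torch.tensor([[0, 1], [1, 0]])
out = conv(torch.ones(4, 2), edge_index, edge_weight=torch.ones(2))
print(out.shape)  # torch.Size([4, 3])
```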

The model from the paper should be roughly:

from functools import partial
from torch_geometric.nn import GCNConv

GraphUnet(32, Conv=partial(GCNConv, improved=True), n_layers=7, is_sum_res=True, is_double_conv=False)

Model:

GraphUnet(
  (activation_): ReLU(inplace)
  (convs): ModuleList(
    (0): GCNConv(32, 64)
    (1): GCNConv(64, 128)
    (2): GCNConv(128, 256)
    (3): GCNConv(256, 256)
    (4): GCNConv(256, 128)
    (5): GCNConv(128, 64)
    (6): GCNConv(64, 32)
  )
  (norms): ModuleList(
    (0): Identity()
    (1): Identity()
    (2): Identity()
    (3): Identity()
    (4): Identity()
    (5): Identity()
    (6): Identity()
  )
  (pools): ModuleList(
    (0): TopKPooling(64, ratio=0.5)
    (1): TopKPooling(128, ratio=0.5)
    (2): TopKPooling(256, ratio=0.5)
  )
)

rusty1s commented 4 years ago

Thank you. It would be great if you could submit a PR to credit you (can be dirty, I can clean it up).

YannDubs commented 4 years ago

OK, I'll send the code above as a PR. Do you want it in nn? (It's a meta-layer/network rather than a layer, so I'm not sure where it should go.)

rusty1s commented 4 years ago

You can put it in nn.models.