HongyangGao / Graph-U-Nets

Pytorch implementation of Graph U-Nets (ICML19)
http://proceedings.mlr.press/v97/gao19a/gao19a.pdf
GNU General Public License v3.0
513 stars 100 forks source link

Variational Auto-Encoder ? #15

Closed ahariri13 closed 4 years ago

ahariri13 commented 4 years ago

I am working on a graph variational auto-encoder to reconstruct node features, and I was wondering whether a modified Graph U-Net could do this. Could I get your opinion on the following code? It is quite similar to the original. The only differences are that I split the model into encoder/decoder functions, added two encoder outputs (mu and sigma), and used the KL divergence in the loss function.

class VarGraphUNet(torch.nn.Module):
    """Graph U-Net used as a node-level variational auto-encoder.

    The encoder stacks GCN convolutions interleaved with TopK pooling. Its
    bottleneck node features are mapped to a mean (``muLay``) and a log-scale
    (``sigLay``) per node. The decoder maps a latent sample back to the
    bottleneck width (``dec``) and un-pools level by level, merging in the
    encoder feature map of each level through a skip connection.

    NOTE(review): ``decode`` needs the encoder's skip features, edge indices
    and pooling permutations, so this model cannot decode a latent drawn from
    the prior on its own.

    Args:
        in_channels: Size of each input node feature.
        hidden_channels: Width of the first encoder level. Each deeper level
            doubles it, except the last, which keeps its width.
        out_channels: Size of each reconstructed node feature.
        depth: Number of pooling levels (must be >= 1).
        pool_ratios: TopK pooling ratio(s), expanded to ``depth`` entries
            with ``repeat``.
        sum_res: If True, merge skip connections by summation; otherwise
            concatenate them along the last (feature) dimension.
        act: Non-linearity applied after every convolution except the final
            decoder convolution.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, depth,
                 pool_ratios, sum_res=True, act=F.tanh):
        super().__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.sum_res = sum_res

        # widths[k] is the feature width of the encoder map kept as skip
        # level k (xs[k] in ``encode``). The last down conv does not widen,
        # so the deepest width is also the bottleneck width.
        widths = self._level_channels(hidden_channels, depth)
        bottleneck = widths[-1]

        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, widths[0], improved=True))
        for i in range(depth):
            self.pools.append(TopKPooling(widths[i], self.pool_ratios[i]))
            out_width = widths[min(i + 1, depth - 1)]
            self.down_convs.append(GCNConv(widths[i], out_width, improved=True))

        # up_convs[i] consumes the merged (skip + un-pooled) map of level
        # j = depth - 1 - i. Concatenation doubles that map's width, which
        # the original ignored, breaking sum_res=False.
        self.up_convs = torch.nn.ModuleList()
        for i in range(depth):
            j = depth - 1 - i
            conv_in = widths[j] if sum_res else 2 * widths[j]
            conv_out = widths[j - 1] if j > 0 else out_channels
            self.up_convs.append(GCNConv(conv_in, conv_out, improved=True))

        # The latent heads act on the encoder's final (bottleneck) width, and
        # ``dec`` maps back to that width so the first un-pooling step can
        # scatter into xs[depth - 1]. The original sized these from the
        # decoder-side width, which only matched when depth == 1.
        latent = 2 * bottleneck
        self.muLay = torch.nn.Linear(bottleneck, latent)
        self.sigLay = torch.nn.Linear(bottleneck, latent)
        self.dec = torch.nn.Linear(latent, bottleneck)
        self.drop = torch.nn.Dropout(p=0.3)

    @staticmethod
    def _level_channels(hidden_channels, depth):
        """Return the encoder feature width at each of the ``depth`` skip levels."""
        base = int(hidden_channels)
        return [base * 2 ** level for level in range(depth)]

    def encode(self, x, edge_index, batch):
        """Run the down path and project the bottleneck to latent statistics.

        Args:
            x: Node feature matrix fed to the first GCN convolution.
            edge_index: Edge index tensor; its second dimension counts edges.
            batch: Graph-assignment vector for the nodes, or None to put all
                nodes in a single graph.

        Returns:
            Tuple ``(mu, log_scale, xs, edge_indices, edge_weights, perms)``:
            - ``mu`` and ``log_scale`` are the outputs of ``muLay`` and
              ``sigLay`` on the bottleneck features.
            - ``xs``, ``edge_indices``, ``edge_weights`` and ``perms`` are
              the per-level skip features, graphs and TopK permutations
              that ``decode`` consumes.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))

        x = self.act(self.down_convs[0](x, edge_index, edge_weight))

        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        for i in range(1, self.depth + 1):
            edge_index, edge_weight = self.augment_adj(edge_index, edge_weight,
                                                       x.size(0))
            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
                x, edge_index, edge_weight, batch)
            x = self.drop(x)
            x = self.act(self.down_convs[i](x, edge_index, edge_weight))

            # The deepest level is the bottleneck. Its features feed the
            # latent heads instead of being kept as a skip connection.
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
                edge_weights.append(edge_weight)
            perms.append(perm)

        mu = self.muLay(x)
        log_scale = self.sigLay(x)
        return mu, log_scale, xs, edge_indices, edge_weights, perms

    def reparametrize(self, mu, logvar):
        """Return ``mu + eps * exp(logvar)`` with ``eps ~ N(0, I)`` in training mode, else ``mu``.

        NOTE(review): ``exp(logvar)`` is used directly as the noise scale.
        Here ``logvar`` therefore behaves as a log standard deviation (as in
        PyG's VGAE), not a log variance. If the KL term of the loss treats it
        as log sigma^2, the scale should be ``exp(0.5 * logvar)``. Confirm
        against the loss.
        """
        if not self.training:
            return mu
        return mu + torch.randn_like(logvar) * torch.exp(logvar)

    def decode(self, z, xs, edge_indices, edge_weights, perms):
        """Map latent node codes back to node features along the up path.

        ``z`` must hold one row per node kept by the last pooling step. At
        each level the current map is scattered back onto that level's node
        set via ``perms[j]``; rows of dropped nodes stay zero. The map is then
        merged with the stored skip features ``xs[j]`` and convolved.
        """
        z = self.dec(z)
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            up = torch.zeros_like(res)
            up[perms[j]] = z
            z = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            z = self.up_convs[i](z, edge_indices[j], edge_weights[j])
            if i < self.depth - 1:  # no activation on the output layer
                z = self.act(z)
        return z

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        """Replace the graph by its 2-hop connectivity before pooling.

        Adds self-loops and squares the weighted adjacency with a
        sparse-sparse product. It then removes the self-loops again, so nodes
        up to two hops apart become direct neighbours.
        """
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight

    def forward(self, x, adj, lengs):
        """Encode, sample a latent, and decode.

        Args:
            x: Node features.
            adj: Edge index tensor (passed to ``encode`` as ``edge_index``).
            lengs: Batch-assignment vector or None (passed as ``batch``).

        Returns:
            ``(reconstruction, mu, log_scale)``. The last two are returned so
            the caller can compute the KL term of the loss.
        """
        mu, log_scale, xs, edge_indices, edge_weights, perms = self.encode(
            x, adj, lengs)
        z = self.reparametrize(mu, log_scale)  # z = mu + eps * scale
        recon = self.decode(z, xs, edge_indices, edge_weights, perms)
        return recon, mu, log_scale
HongyangGao commented 4 years ago

Hi there,

I see that you use skip connections in the graph VAE. I am not sure the skip connections will be available when you generate graphs using only your decoder.