I am working on a graph variational auto-encoder to reconstruct node features, and I was wondering whether a modification of the graph U-Net could do this. Could I get your opinion on the following code? It is quite similar to the original; the only differences are that I split it into encoder/decoder functions, gave the encoder two outputs (mu and sigma), and use the KL divergence in the loss function.
import torch
import torch.nn.functional as F
from torch_sparse import spspmm
from torch_geometric.nn import GCNConv, TopKPooling
from torch_geometric.utils import add_self_loops, remove_self_loops, sort_edge_index
from torch_geometric.utils.repeat import repeat


class VarGraphUNet(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth,
                 pool_ratios, sum_res=True, act=torch.tanh):
        super().__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.sum_res = sum_res

        # Down (encoder) path: the channel width doubles at every level except the last.
        channels = hidden_channels
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, channels, improved=True))
        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            if i < depth - 1:
                self.down_convs.append(GCNConv(channels, channels * 2, improved=True))
                channels = channels * 2
            else:
                self.down_convs.append(GCNConv(channels, channels, improved=True))

        bottleneck = channels  # width at the deepest level, used by the variational heads below
        in_channels = channels if sum_res else 2 * channels  # NOTE: unused; the up path assumes sum_res=True

        # Up (decoder) path: the channel width halves at every level.
        self.up_convs = torch.nn.ModuleList()
        for i in range(depth - 1):
            self.up_convs.append(GCNConv(channels, channels // 2, improved=True))
            channels = channels // 2
        self.up_convs.append(GCNConv(channels, out_channels, improved=True))
        # self.reset_parameters()

        # Variational bottleneck: mu / log-std heads plus a layer mapping the sampled
        # latent back to the bottleneck width. These are built from `bottleneck`
        # rather than `channels` (which the up-path loop above has already halved),
        # otherwise the shapes only match for depth == 1.
        self.muLay = torch.nn.Linear(bottleneck, bottleneck * 2)
        self.sigLay = torch.nn.Linear(bottleneck, bottleneck * 2)
        self.dec = torch.nn.Linear(bottleneck * 2, bottleneck)
        self.drop = torch.nn.Dropout(p=0.3)

    def encode(self, x, edge_index, batch):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))

        x = self.down_convs[0](x, edge_index, edge_weight)
        x = self.act(x)

        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        for i in range(1, self.depth + 1):
            edge_index, edge_weight = self.augment_adj(edge_index, edge_weight, x.size(0))
            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
                x, edge_index, edge_weight, batch)
            x = self.drop(x)
            x = self.down_convs[i](x, edge_index, edge_weight)
            x = self.act(x)

            if i < self.depth:
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
            perms += [perm]  # every pooling permutation is needed for the un-pooling in decode

        # Variational heads at the bottleneck: per-node mu and log-std.
        mu2 = self.muLay(x)
        sig2 = self.sigLay(x)
        return mu2, sig2, xs, edge_indices, edge_weights, perms

    def reparametrize(self, mu, logvar):
        # NOTE: `logvar` is used directly as log(std), i.e. std = exp(logvar).
        if self.training:
            return mu + torch.randn_like(logvar) * torch.exp(logvar)
        return mu

    def decode(self, z, xs, edge_indices, edge_weights, perms):
        z = self.dec(z)  # map the sampled latent back to the bottleneck width
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            # Un-pool: scatter the coarse node features back to their original
            # positions, then merge with the skip connection from the down path.
            up = torch.zeros_like(res)
            up[perm] = z
            z = res + up if self.sum_res else torch.cat((res, up), dim=-1)

            z = self.up_convs[i](z, edge_index, edge_weight)
            z = self.act(z) if i < self.depth - 1 else z
        return z

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        # Square the self-loop-augmented adjacency, as in the original Graph U-Net,
        # so that pooling is less likely to disconnect the graph.
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight, num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes, num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight

    def forward(self, x, adj, lengs):
        # adj: edge_index of shape [2, num_edges]; lengs: batch assignment vector.
        mu2, sig2, xs, edge_indices, edge_weights, perms = self.encode(x, adj, lengs)
        z = self.reparametrize(mu2, sig2)  # z = mu + eps * sigma
        z2 = self.decode(z, xs, edge_indices, edge_weights, perms)
        return z2, mu2, sig2
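
For context, the loss I pair with this is roughly the following (just a sketch: `vae_loss` and `kl_weight` are placeholder names, the reconstruction term is a plain MSE on the node features, and the KL term is written for the log-std parameterisation used in `reparametrize`, i.e. std = exp(sig2)):

def vae_loss(recon_x, x, mu, log_std, kl_weight=1.0):
    # Reconstruction of the node features (assumes out_channels == in_channels).
    recon = F.mse_loss(recon_x, x)
    # KL(N(mu, sigma^2) || N(0, 1)) with log_std = log(sigma), so the log-variance is 2 * log_std.
    kl = -0.5 * torch.mean(1 + 2 * log_std - mu.pow(2) - (2 * log_std).exp())
    return recon + kl_weight * kl

# e.g. in a training step:
# out, mu, sig = model(data.x, data.edge_index, data.batch)
# loss = vae_loss(out, data.x, mu, sig)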