[Open] QinyuMa316 opened this issue 9 months ago
我也出现了该问题，请问有什么解决方法吗？ (I am running into the same problem too — is there a fix?)
There seems to be a mix-up between the atom and bond feature dimensions (`m_dim` and `b_dim`). In its forward pass, the `Discriminator` concatenates the node features with the output of the `GraphConvolution` and passes the result to the `GraphAggregation`, as follows:
# models.py
h = self.gcn_layer(annotations, adj)
annotations = torch.cat((h, hidden, node) if hidden is not None else (h, node), -1)
h = self.agg_layer(annotations, torch.tanh)
Here, the node features have a dimensionality of `m_dim`, representing the different atom types. However, `GraphAggregation` uses `in_features + b_dim` as the input size of its linear layers (where `in_features` equals the output dimensionality of the `GraphConvolution`):
# layers.py
self.sigmoid_linear = nn.Sequential(nn.Linear(in_features + b_dim, out_features), nn.Sigmoid())
self.tanh_linear = nn.Sequential(nn.Linear(in_features + b_dim, out_features), nn.Tanh())
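This mismatch causes shape errors whenever `m_dim` differs from `b_dim`. The last dimension of the concatenated tensor is `graph_conv_dim[-1] + m_dim`, but the `nn.Linear` layers expect `graph_conv_dim[-1] + b_dim` inputs. The fix is to use `m_dim` instead of `b_dim` when sizing the layers in `GraphAggregation`: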
# models.py
class Discriminator(nn.Module):
"""Discriminator network with PatchGAN."""
def __init__(self, conv_dim, m_dim, b_dim, dropout):
super(Discriminator, self).__init__()
graph_conv_dim, aux_dim, linear_dim = conv_dim
# discriminator
self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout)
- self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, b_dim, dropout)
+ self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
# layers.py
class GraphAggregation(Module):
- def __init__(self, in_features, out_features, b_dim, dropout):
+ def __init__(self, in_features, out_features, m_dim, dropout):
super(GraphAggregation, self).__init__()
- self.sigmoid_linear = nn.Sequential(nn.Linear(in_features+b_dim, out_features),
+ self.sigmoid_linear = nn.Sequential(nn.Linear(in_features+m_dim, out_features),
nn.Sigmoid())
- self.tanh_linear = nn.Sequential(nn.Linear(in_features+b_dim, out_features),
+ self.tanh_linear = nn.Sequential(nn.Linear(in_features+m_dim, out_features),
nn.Tanh())
self.dropout = nn.Dropout(dropout)
I hope this helps!
Does this happen to you too when you run `main.py`?