yongqyu / MolGAN-pytorch

PyTorch implementation of MolGAN: An implicit generative model for small molecular graphs (https://arxiv.org/abs/1805.11973)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (144x70 and 69x128) #14

QinyuMa316 opened this issue 9 months ago

QinyuMa316 commented 9 months ago

Does this happen to you too when you run main.py?

chengYu23 commented 5 months ago

I ran into the same problem. Is there any known fix?

sotskopa commented 5 months ago

Problem

There seems to be a mixup between the atom and bond feature dimensions (m_dim and b_dim). In its forward pass, the Discriminator concatenates the node features with the output of the GraphConvolution and passes the result to the GraphAggregation as follows:

```python
# models.py
h = self.gcn_layer(annotations, adj)
annotations = torch.cat((h, hidden, node) if hidden is not None else (h, node), -1)

h = self.agg_layer(annotations, torch.tanh)
```
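To make the sizes concrete, here is a minimal sketch of that concatenation, using hypothetical sizes (graph_conv_dim[-1] = 64, m_dim = 6, a batch of 16 graphs with 9 nodes each, hidden = None) chosen so the numbers line up with the error in the title:

```python
import torch

batch, nodes = 16, 9                     # hypothetical: QM9-sized graphs, 9 nodes per molecule
h = torch.randn(batch, nodes, 64)        # GraphConvolution output, graph_conv_dim[-1] = 64
node = torch.randn(batch, nodes, 6)      # one-hot atom types, m_dim = 6
annotations = torch.cat((h, node), -1)   # the "hidden is None" branch
print(annotations.shape)                 # torch.Size([16, 9, 70]) -> last dim is 64 + m_dim
```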

Here, the node features have a dimensionality of m_dim, representing the different atom types. However, the GraphAggregation uses in_features + b_dim as the input size for its linear layers (where in_features equals the output dimensionality of the GraphConvolution):

```python
# layers.py
self.sigmoid_linear = nn.Sequential(nn.Linear(in_features + b_dim, out_features), nn.Sigmoid())
self.tanh_linear = nn.Sequential(nn.Linear(in_features + b_dim, out_features), nn.Tanh())
```

This mismatch causes the shape error reported above.
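With the same hypothetical sizes as above plus b_dim = 5 and aux_dim = 128, the mismatch reproduces in isolation: the Linear expects in_features + b_dim = 69 input features but receives in_features + m_dim = 70:

```python
import torch
import torch.nn as nn

in_features, m_dim, b_dim = 64, 6, 5                    # hypothetical sizes matching the title
annotations = torch.randn(16, 9, in_features + m_dim)   # what the Discriminator actually passes
linear = nn.Linear(in_features + b_dim, 128)            # what GraphAggregation allocates
linear(annotations)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (144x70 and 69x128)
```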

Solution

Just use m_dim instead of b_dim in the GraphAggregation.

```diff
 # models.py
 class Discriminator(nn.Module):
     """Discriminator network with PatchGAN."""
     def __init__(self, conv_dim, m_dim, b_dim, dropout):
         super(Discriminator, self).__init__()

         graph_conv_dim, aux_dim, linear_dim = conv_dim
         # discriminator
         self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout)
-        self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, b_dim, dropout)
+        self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
```

```diff
 # layers.py
 class GraphAggregation(Module):

-    def __init__(self, in_features, out_features, b_dim, dropout):
+    def __init__(self, in_features, out_features, m_dim, dropout):
         super(GraphAggregation, self).__init__()
-        self.sigmoid_linear = nn.Sequential(nn.Linear(in_features+b_dim, out_features),
+        self.sigmoid_linear = nn.Sequential(nn.Linear(in_features+m_dim, out_features),
                                             nn.Sigmoid())
-        self.tanh_linear = nn.Sequential(nn.Linear(in_features+b_dim, out_features),
+        self.tanh_linear = nn.Sequential(nn.Linear(in_features+m_dim, out_features),
                                          nn.Tanh())
         self.dropout = nn.Dropout(dropout)
```
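As a quick sanity check (same hypothetical sizes as in the snippets above), the patched GraphAggregation now accepts the concatenated annotations, using the same call pattern the Discriminator uses in models.py:

```python
import torch
from layers import GraphAggregation  # the patched class from the diff above

agg = GraphAggregation(in_features=64, out_features=128, m_dim=6, dropout=0.0)
annotations = torch.randn(16, 9, 64 + 6)  # (batch, nodes, in_features + m_dim)
out = agg(annotations, torch.tanh)        # same call as in models.py; no shape error anymore
```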

I hope this helps!