Open xeroCBW opened 2 years ago
edge_set_num is the number of edge sets. It is 1 here because we use only a single edge set, which contains the fully connected edges. The parameters of the GNN layer are set by the input_dim, dim and inter_dim variables.
Why are the GNN layers created from the edge sets rather than from the nodes?
self.gnn_layers = nn.ModuleList([GNNLayer(input_dim, dim, inter_dim=dim + embed_dim, heads=1) for i in range(edge_set_num)])
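
To make the point concrete, here is a minimal sketch in plain Python, not the repository's code. StandInLayer is a hypothetical placeholder for GNNLayer that only records its dimensions, and fully_connected_edges, build_layers and the sizes used are assumptions chosen for illustration. It shows that the loop produces one layer per edge set, and that the layer's dimensions come from input_dim, dim and embed_dim, not from the number of nodes.

```python
from dataclasses import dataclass
from itertools import permutations


@dataclass
class StandInLayer:
    """Hypothetical placeholder for GNNLayer; it only records its dimensions."""
    input_dim: int
    dim: int
    inter_dim: int
    heads: int = 1

    def weight_shape(self):
        # Assumed shape of a linear map shared by all nodes; it does not
        # depend on how many nodes or edges the graph has.
        return (self.heads * self.dim, self.input_dim)


def fully_connected_edges(node_num):
    """All directed edges (i, j) with i != j among node_num nodes."""
    return list(permutations(range(node_num), 2))


def build_layers(edge_sets, input_dim, dim, embed_dim):
    # One layer per edge set, mirroring the list comprehension in the snippet.
    return [
        StandInLayer(input_dim, dim, inter_dim=dim + embed_dim, heads=1)
        for _ in range(len(edge_sets))
    ]


# Assumed example sizes, chosen only for illustration.
edge_sets = [fully_connected_edges(node_num=5)]  # a single edge set
layers = build_layers(edge_sets, input_dim=10, dim=64, embed_dim=64)

print(len(edge_sets[0]))         # 20 directed edges among 5 nodes
print(len(layers))               # 1 layer, because there is 1 edge set
print(layers[0].weight_shape())  # (64, 10)
```

Changing node_num changes the number of edges in the set but not the number of layers or their dimensions, which is why the loop runs over edge sets.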