Closed zrqiao closed 3 years ago
The batched information is not maintained after subgraph()
@classicsong Thanks! Then what is the recommended way to perform message passing/readout on subgraphs after batching?
Is your process: graphs -> batched graph -> subgraph -> message passing -> readout?
Are you using the subgraphs for minibatch training?
Can you try:
# graphs: the individual graphs, before batching
subgs = [g.subgraph(filter) for g in graphs]
batched_subg = dgl.batch(subgs)
...
For future reference, the way I did this was:
graph.ndata["h"] = feat
individual_graphs = dgl.unbatch(graph)
node_ids = [g.nodes()[g.ndata["filter_attr"].bool()] for g in individual_graphs]
all_subgraphs = [dgl.node_subgraph(g, ids) for g, ids in zip(individual_graphs, node_ids)]
batched_subgraphs = dgl.batch(all_subgraphs)
🐛 Bug
Taking the subgraph of a batched DGLGraph causes the batches to be merged, leading to incorrect behavior when performing dgl.sum_nodes.
To Reproduce
Steps to reproduce the behavior:
Expected behavior
g1.batch_num_nodes should print out the number of nodes for the individual graphs.
Environment