dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

Subgraph of batched DGLGraph #1865

Closed: zrqiao closed this issue 3 years ago

zrqiao commented 4 years ago

🐛 Bug

Taking a subgraph of a batched DGLGraph merges the member graphs into a single graph, so dgl.sum_nodes on the result sums over all nodes at once instead of per graph.

To Reproduce

Steps to reproduce the behavior:

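import dgl

# g: a batched DGLGraph with node feature "x"; filter: the node IDs to keep
print(g.batch_num_nodes)       # prints the number of nodes in each individual graph
g1 = g.subgraph(filter)
g1.copy_from_parent()          # copy node/edge features from the parent graph
print(g1.batch_num_nodes)      # instead prints a single scalar
res = dgl.sum_nodes(g1, "x")   # incorrect: sums over the merged graph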

Expected behavior

g1.batch_num_nodes should report the number of nodes in each individual graph.

Environment

classicsong commented 4 years ago

The batch information is not preserved by subgraph().
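Because `dgl.batch` gives the nodes of each member graph a consecutive block of IDs in the batched graph, the per-graph node counts of a node-induced subgraph can in principle be recovered from the original counts and the kept node IDs. The sketch below is framework-free Python that models only this bookkeeping; `subgraph_batch_num_nodes` is a hypothetical helper, not a DGL function. The counts are only meaningful as batch sizes if the subgraph keeps its nodes grouped by graph (for example, with the kept IDs sorted).

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Sequence


def subgraph_batch_num_nodes(batch_num_nodes: Sequence[int],
                             kept_ids: Sequence[int]) -> List[int]:
    """Count how many kept nodes fall inside each member graph of a batch.

    Assumes graph i owns a consecutive block of global IDs, laid out in order.
    """
    # Exclusive upper bound of each graph's ID block, e.g. [3, 5, 9] for [3, 2, 4].
    ends = list(accumulate(batch_num_nodes))
    total = ends[-1] if ends else 0
    counts = [0] * len(batch_num_nodes)
    for nid in kept_ids:
        if not 0 <= nid < total:
            raise ValueError(f"node id {nid} is outside the batched graph")
        # The first block whose end exceeds nid belongs to the graph that owns it.
        counts[bisect_right(ends, nid)] += 1
    return counts


# Three graphs with 3, 2 and 4 nodes; keep global nodes 0, 2, 5, 6, 8.
print(subgraph_batch_num_nodes([3, 2, 4], [0, 2, 5, 6, 8]))  # [2, 0, 3]
```

Graphs with no kept nodes get a count of 0, so the result always has one entry per original graph.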

zrqiao commented 4 years ago

@classicsong Thanks! Then what is the recommended way to perform message passing/readout on subgraphs after batching?

classicsong commented 4 years ago

Is your pipeline: graphs -> batched graph -> subgraph -> message passing -> readout?

Are you using subgraphs for minibatch training?

classicsong commented 4 years ago

Can you try:

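import dgl
# Split the batch, take the subgraph of each member graph, then re-batch.
# `filter` stands for the node IDs to keep in each individual graph.
subgs = [sub.subgraph(filter) for sub in dgl.unbatch(g)]
batched_subg = dgl.batch(subgs)
...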
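The suggestion above works because the batch boundaries are rebuilt by `dgl.batch` after each member graph has been cut down on its own. The toy model below, in plain Python with no DGL dependency, mimics the last two steps starting from the already-split list of graphs: an induced subgraph per graph, then re-batching with shifted node IDs. The `Graph` tuple, `induced_subgraph` and `batch` are hypothetical stand-ins, not DGL's data structures or implementation.

```python
from typing import List, Sequence, Tuple

# A toy graph: (number of nodes, list of (src, dst) edges). Not a DGL type.
Graph = Tuple[int, List[Tuple[int, int]]]


def induced_subgraph(graph: Graph, keep: Sequence[int]) -> Graph:
    """Keep the listed nodes and the edges between them, relabelled 0..k-1."""
    num_nodes, edges = graph
    if any(not 0 <= n < num_nodes for n in keep):
        raise ValueError("kept node id out of range")
    new_id = {old: new for new, old in enumerate(keep)}
    sub_edges = [(new_id[u], new_id[v]) for u, v in edges
                 if u in new_id and v in new_id]
    return len(keep), sub_edges


def batch(graphs: Sequence[Graph]) -> Tuple[Graph, List[int]]:
    """Merge graphs into one, shifting node IDs; also return per-graph node counts."""
    offset, all_edges, counts = 0, [], []
    for num_nodes, edges in graphs:
        all_edges.extend((u + offset, v + offset) for u, v in edges)
        counts.append(num_nodes)
        offset += num_nodes
    return (offset, all_edges), counts


# Two small graphs and, for each, the nodes to keep.
graphs = [(3, [(0, 1), (1, 2)]), (4, [(0, 1), (2, 3), (3, 0)])]
keeps = [[0, 1], [0, 2, 3]]

subgraphs = [induced_subgraph(g, k) for g, k in zip(graphs, keeps)]
batched, batch_num_nodes = batch(subgraphs)
print(batched)          # (5, [(0, 1), (3, 4), (4, 2)])
print(batch_num_nodes)  # [2, 3]
```

Because the per-graph counts are recorded at batching time, they stay correct no matter how many nodes each subgraph dropped.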
jalexvig commented 2 years ago

For future reference, this is how I did it:

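import dgl

# graph: a batched DGLGraph; feat: its node features; "filter_attr": a per-node 0/1 mask
graph.ndata["h"] = feat
individual_graphs = dgl.unbatch(graph)  # split back into the member graphs
node_ids = [g.nodes()[g.ndata["filter_attr"].bool()] for g in individual_graphs]
all_subgraphs = [dgl.node_subgraph(g, ids) for g, ids in zip(individual_graphs, node_ids)]
batched_subgraphs = dgl.batch(all_subgraphs)  # per-graph node counts are restored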
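With per-graph node counts restored, a sum readout reduces each consecutive block of node features separately, which is the per-graph result the issue expected from `dgl.sum_nodes`. A minimal framework-free sketch, assuming scalar node features stored in batch order; `segment_sum` is a hypothetical helper, not DGL's implementation. The second call shows what happens when the batch has collapsed into one graph.

```python
from typing import List, Sequence


def segment_sum(values: Sequence[float], batch_num_nodes: Sequence[int]) -> List[float]:
    """Sum consecutive blocks of `values`, one block per graph in the batch."""
    if sum(batch_num_nodes) != len(values):
        raise ValueError("batch_num_nodes must add up to the number of nodes")
    sums, start = [], 0
    for n in batch_num_nodes:
        sums.append(sum(values[start:start + n]))
        start += n
    return sums


x = [1.0, 2.0, 3.0, 4.0, 5.0]
print(segment_sum(x, [2, 3]))  # per-graph readout: [3.0, 12.0]
print(segment_sum(x, [5]))     # merged batch: [15.0]
```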