google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

aggregate node from sending nodes directly #110

Closed TianrenWang closed 3 years ago

TianrenWang commented 4 years ago

What is the correct way to implement node aggregation from all sending nodes? Let's say I have nodes A, B, C, D, and E. A and B send to C, and D sends to E. I just want C = f(A) + f(B) and D = f(c). Edges are ignored in what I want to do.

The way that I have implemented it so far is with an InteractionNetwork:

import sonnet as snt
from graph_nets import blocks, modules

graph_network = modules.InteractionNetwork(
    edge_model_fn=lambda: snt.nets.MLP(output_sizes=[1]),
    node_model_fn=lambda: snt.nets.MLP(output_sizes=[depth]))

global_block = blocks.GlobalBlock(
    global_model_fn=lambda: snt.nets.MLP(output_sizes=[depth]))

num_recurrent_passes = FLAGS.recurrences
previous_graphs = batch_of_graphs

# Apply the same interaction network and global block on every pass.
for unused_pass in range(num_recurrent_passes):
    previous_graphs = graph_network(previous_graphs)
    previous_graphs = global_block(previous_graphs)

The output size of the edge model is 1 because all edges are just tf.constant([1]). I am asking because my graph neural network is stuck on a loss value, and I am wondering whether it is implemented properly.

EDIT: Actually, I figured out why my loss is stuck, but please verify that this is the correct way to do it anyway.

alvarosg commented 4 years ago

A and B send to C, and D sends to E. I just want C = f(A) + f(B) and D = f(c)

Do you mean: I just want C = f(A) + f(B) and E = f(D) ?

Otherwise I am not sure I get what you want to do.

TianrenWang commented 4 years ago

Apologies. Yes that was a typo.
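
With the typo corrected, the target computation is C = f(A) + f(B) and E = f(D). As a rough reference for what that means on the example graph, here is a minimal plain-Python sketch. The scalar feature values and the doubling function f are made up for illustration (the thread uses an MLP for f), and aggregate_from_senders is a hypothetical helper name, not part of graph_nets.

```python
# Example graph from the question: A and B send to C, D sends to E.
edges = [("A", "C"), ("B", "C"), ("D", "E")]  # (sender, receiver) pairs

# Made-up scalar node features, purely for illustration.
features = {"A": 1.0, "B": 2.0, "C": 3.0, "D": 4.0, "E": 5.0}


def f(x):
    """Stand-in for the learned node function (an MLP in the thread)."""
    return 2.0 * x


def aggregate_from_senders(edges, features, f):
    """Return, for each node, the sum of f(sender) over its incoming edges."""
    # Nodes with no incoming edge keep a sum of 0.0.
    aggregated = {node: 0.0 for node in features}
    for sender, receiver in edges:
        aggregated[receiver] += f(features[sender])
    return aggregated


new_features = aggregate_from_senders(edges, features, f)
print(new_features["C"], new_features["E"])
```

Edge features play no role here; the edges only determine which sender contributes to which receiver.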

alvarosg commented 4 years ago

In that case the computation you want is much simpler than the model you are trying to use. It is essentially what is usually referred to as a Graph Convolutional Network: a single function "f" is computed on the nodes, and no computation happens on the edges. This can be written bottom-up in terms of our broadcast and aggregation operators for graphs in gn.blocks:

import graph_nets as gn
import sonnet as snt
import tensorflow as tf

model_fn = snt.nets.MLP(...)

for unused_pass in range(num_recurrent_passes):
  # Update the node features with the function "f".
  updated_nodes = model_fn(previous_graphs.nodes)
  temporary_graph = previous_graphs.replace(nodes=updated_nodes)

  # Copy each node's features onto the edges for which it is the sender.
  nodes_at_edges = gn.blocks.broadcast_sender_nodes_to_edges(temporary_graph)
  temporary_graph = temporary_graph.replace(edges=nodes_at_edges)

  # Sum all of the edges received by each node.
  aggregator = gn.blocks.ReceivedEdgesToNodesAggregator(
      tf.math.unsorted_segment_sum)
  nodes_with_aggregated_edges = aggregator(temporary_graph)
  previous_graphs = previous_graphs.replace(nodes=nodes_with_aggregated_edges)

More information about the ops and building blocks in gn.blocks is available in our paper "Relational inductive biases, deep learning, and graph networks" (https://arxiv.org/abs/1806.01261).
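
To make the three steps concrete (update nodes, broadcast sender features to edges, sum at receivers), here is a NumPy sketch on the A to E example from the question. It only mimics the steps under assumptions and does not use graph_nets itself: the node features, the add-one function f, and the integer node indices are made up for illustration, with plain array indexing standing in for the sender broadcast and np.add.at standing in for an unsorted segment sum.

```python
import numpy as np

# Made-up data: 5 nodes (A..E mapped to indices 0..4), 2 features each.
nodes = np.arange(10, dtype=np.float32).reshape(5, 2)
senders = np.array([0, 1, 3])    # A, B, D
receivers = np.array([2, 2, 4])  # C, C, E


def f(x):
    """Stand-in for the shared node function; the thread uses an MLP here."""
    return x + 1.0


# Step 1: apply the node function to every node.
updated_nodes = f(nodes)

# Step 2: copy each sender's features onto its outgoing edge.
nodes_at_edges = updated_nodes[senders]

# Step 3: sum the edge features per receiver (an unsorted segment sum).
aggregated = np.zeros_like(updated_nodes)
np.add.at(aggregated, receivers, nodes_at_edges)

print(aggregated)
```

In this sketch, nodes with no incoming edges (A, B, and D) end up with all-zero features after a pass, since nothing is summed into them. If a node should also retain its own features, one option would be to add updated_nodes to the aggregated result, though that goes beyond what the question asks for.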