google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0
5.34k stars, 783 forks

What is the most efficient way to create a GraphTuple from batched tensors #120

Closed. Joshuaalbert closed this issue 4 years ago.

Joshuaalbert commented 4 years ago

My use case is that I would like to use graph_nets to perform logical reasoning on the output of some prior batched TF operations. Let me present an example:

import tensorflow as tf

B, W, H, C = 10, 200, 200, 3
data = tf.random.normal(shape=(B, W, H, C))
# [B, W, H, 8]
A = tf.keras.layers.Conv2D(8, (3, 3), padding='same')(data)
# I would now like to turn each "pixel" into its own node and add fully connected
# edges. Thus I would like a GraphsTuple representing a batch of B graphs, each with W*H
# nodes with node attributes of shape [8] and (W*H)**2 edges (including self loops).

What I can imagine doing is creating a dynamic fully connected graph for the nodes of a single graph and then adding offsets to the senders and receivers for each element of the batch. Is there a prebuilt function for this use case (which I imagine would be quite common)?
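The offset idea above can be sketched in NumPy. This is a minimal illustration, not a graph_nets API: `fully_connected_batch` is a hypothetical helper that builds the sender/receiver indices of one fully connected graph (self loops included) and shifts them by `b * num_nodes` for graph `b`. The shift assumes the batched layout graph_nets uses, where the nodes of all graphs are concatenated graph by graph and senders/receivers index into that concatenated array.

```python
import numpy as np


def fully_connected_batch(batch_size, num_nodes):
    """Hypothetical helper: senders/receivers for a batch of fully connected graphs.

    Each graph has `num_nodes` nodes and num_nodes**2 edges (self loops included).
    Indices refer to the concatenated node array of shape [batch_size * num_nodes, F].
    """
    # All (receiver, sender) pairs of a single graph.
    receivers, senders = np.meshgrid(np.arange(num_nodes), np.arange(num_nodes), indexing="ij")
    receivers, senders = receivers.ravel(), senders.ravel()
    # Shift graph b's indices by b * num_nodes so they point into its own block of nodes.
    offsets = np.repeat(np.arange(batch_size) * num_nodes, num_nodes ** 2)
    senders = np.tile(senders, batch_size) + offsets
    receivers = np.tile(receivers, batch_size) + offsets
    n_node = np.full(batch_size, num_nodes)
    n_edge = np.full(batch_size, num_nodes ** 2)
    return n_node, n_edge, senders, receivers


n_node, n_edge, senders, receivers = fully_connected_batch(batch_size=2, num_nodes=3)
print(senders)    # [0 1 2 0 1 2 0 1 2 3 4 5 3 4 5 3 4 5]
print(receivers)  # [0 0 0 1 1 1 2 2 2 3 3 3 4 4 4 5 5 5]
```

The edge count grows quadratically: with W = H = 200 each graph has 40,000 nodes and 1.6 billion edges, so it is worth checking the construction on small sizes first.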

alvarosg commented 4 years ago

Thanks for your message!

There are several ways to do this.

In this case, the number of nodes and the batch size are known statically at compile time (assuming the batch size is fixed), so you can build a GraphsTuple representing a batch of sets of num_nodes nodes each, and fully connect it using the provided utils_tf.fully_connect_graph_static method:


import tensorflow as tf
from graph_nets import utils_tf

B, W, H, C = 10, 200, 200, 3
num_nodes = W * H

# Take your tensor, merge the batch, width and height dimensions,
# and use the result as the node features of B graphs with `num_nodes` nodes each.
data = tf.random.normal(shape=(B, W, H, C))
input_nodes = tf.reshape(data, [B * W * H, C])

# For the graph structure we only need to know that each graph has `num_nodes` nodes.
data_dict = {"n_node": num_nodes}
graphs_tuple_with_sets_of_nodes = utils_tf.data_dicts_to_graphs_tuple(B * [data_dict])

# fully_connect_graph_static infers the number of nodes per graph from the static
# leading dimension of `nodes`, so the node features are set before connecting.
graphs_tuple_with_connectivity = utils_tf.fully_connect_graph_static(
    graphs_tuple_with_sets_of_nodes.replace(nodes=input_nodes),
    exclude_self_edges=False)

# You may even add some dummy edge features if your model requires it.
graphs_tuple_with_connectivity = utils_tf.set_zero_edge_features(
    graphs_tuple_with_connectivity, edge_size=1)
graphs_tuple = graphs_tuple_with_connectivity

# Feed graphs_tuple to the model (`some_graph_network` stands for your graph_nets
# model), and afterwards recover your spatial dimensions.
output_graphs_tuple = some_graph_network(graphs_tuple)
output_spatial_data = tf.reshape(output_graphs_tuple.nodes, [B, W, H, -1])
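The reshape works because a row-major reshape of [B, W, H, C] into [B*W*H, C] puts all pixels of batch element b in the contiguous rows b*W*H to (b+1)*W*H - 1, which is exactly the block of nodes that belongs to graph b when every graph has n_node = W*H. A quick NumPy check of that layout, with small made-up sizes (tf.reshape uses the same row-major order):

```python
import numpy as np

# Small, hypothetical sizes; the layout argument is the same for 10 x 200 x 200 x 3.
B, W, H, C = 2, 3, 4, 5
data = np.random.default_rng(0).normal(size=(B, W, H, C))

# Merge batch, width and height into one node axis (row-major, like tf.reshape).
nodes = data.reshape(B * W * H, C)
num_nodes = W * H

for b in range(B):
    # Graph b owns the contiguous block of rows [b * num_nodes, (b + 1) * num_nodes).
    block = nodes[b * num_nodes:(b + 1) * num_nodes]
    assert np.array_equal(block, data[b].reshape(num_nodes, C))

# Pixel (w, h) of image b is node b * num_nodes + w * H + h.
b, w, h = 1, 2, 3
assert np.array_equal(nodes[b * num_nodes + w * H + h], data[b, w, h])

# The inverse reshape recovers the spatial layout, as in the last step above.
assert np.array_equal(nodes.reshape(B, W, H, -1), data)
```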

Note that you could also pass a data dict with some specific connectivity to data_dicts_to_graphs_tuple if you want to have more sophisticated edge patterns without worrying about adding offsets.
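For example, instead of full connectivity each pixel could be connected only to its 4 neighbours. The sketch below builds one data dict per image using local node indices (0 to W*H - 1); `grid_data_dict` is a hypothetical helper, while `nodes`, `senders` and `receivers` are the usual graph_nets data dict keys. A list of such dicts, e.g. `[grid_data_dict(img) for img in images]`, could then be passed to utils_tf.data_dicts_to_graphs_tuple, which takes care of the per-graph offsets.

```python
import numpy as np


def grid_data_dict(features):
    """Hypothetical helper: data dict for one [W, H, C] image with 4-neighbour edges.

    Node index of pixel (w, h) is w * H + h, matching a row-major reshape.
    Edges are directed; each neighbouring pair appears in both directions.
    """
    W, H, C = features.shape
    index = np.arange(W * H).reshape(W, H)
    senders, receivers = [], []
    # Neighbours along H, then neighbours along W, each pair in both directions.
    for u, v in [(index[:, :-1], index[:, 1:]), (index[:-1, :], index[1:, :])]:
        senders += [u.ravel(), v.ravel()]
        receivers += [v.ravel(), u.ravel()]
    return {
        "nodes": features.reshape(W * H, C).astype(np.float32),
        "senders": np.concatenate(senders).astype(np.int32),
        "receivers": np.concatenate(receivers).astype(np.int32),
    }


d = grid_data_dict(np.zeros((2, 3, 1)))
print(len(d["senders"]))  # 14 = 2 directions * (4 + 3) neighbouring pairs
```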

Joshuaalbert commented 4 years ago

Thanks @alvarosg, that's quite simple. Is there also a simple way to add specific edge attributes that are not known at compile time?

In my case I have a well-defined distance between the nodes that is the same for every element of the batch. I guess the easiest way is to compute these distances once in a matrix and then use the sender and receiver tensors modulo num_nodes to gather the correct elements.
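The matrix-plus-modulo lookup described above can be sketched in NumPy. The positions and edges are made up for illustration; the sketch assumes every graph has the same num_nodes and that node indices are laid out graph by graph, so `index % num_nodes` recovers the local node index. In TensorFlow the same lookup could be written with tf.gather_nd on the distance matrix.

```python
import numpy as np

num_nodes = 3
# Hypothetical node positions, identical for every graph in the batch: [num_nodes, 2].
positions = np.array([[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]])
# Pairwise distance matrix computed once: dist[i, j] = ||p_i - p_j||.
dist = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)

# A few example edges from a batch of 2 graphs (global node indices 0..5).
senders = np.array([0, 1, 3, 5])
receivers = np.array([1, 2, 4, 3])

# Global index modulo num_nodes gives the local index within each graph.
edge_distances = dist[receivers % num_nodes, senders % num_nodes][:, None]  # [num_edges, 1]
print(edge_distances[:, 0])  # [3. 5. 3. 4.]
```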

alvarosg commented 4 years ago

You could do something like this (assuming you start from some node positions to calculate distances):

from graph_nets import blocks

data_position = ...  # [batch_size, width, height, 2] containing the x and y position of each pixel
node_positions = tf.reshape(data_position, [B * W * H, 2])

# Temporarily use the positions as node features, so that the broadcasts below
# gather positions rather than the actual node features.
graphs_tuple_with_positions = graphs_tuple_with_connectivity.replace(
    nodes=node_positions)

# This is just the Euclidean displacement (receiver minus sender position), but it
# could also be any other distance.
edge_distances = (
    blocks.broadcast_receiver_nodes_to_edges(graphs_tuple_with_positions) -
    blocks.broadcast_sender_nodes_to_edges(graphs_tuple_with_positions))

# Attach the distances as edge features of the graph that holds the node features.
graphs_tuple = graphs_tuple.replace(edges=edge_distances)

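Conceptually, blocks.broadcast_receiver_nodes_to_edges and blocks.broadcast_sender_nodes_to_edges gather each edge's receiver or sender node features by index (a gather of graph.nodes by graph.receivers or graph.senders). A NumPy sketch of the same computation with made-up positions, also reducing the displacement to a scalar Euclidean distance in case a single edge feature is preferred:

```python
import numpy as np

# Hypothetical positions for a batch of two 2-node graphs: [total_nodes, 2].
node_positions = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0], [1.0, 2.0]])
# Fully connected edges within each graph (self loops included), global indices.
senders = np.array([0, 1, 0, 1, 2, 3, 2, 3])
receivers = np.array([0, 0, 1, 1, 2, 2, 3, 3])

# Equivalent of broadcasting receiver/sender nodes to edges: gather by edge index.
displacement = node_positions[receivers] - node_positions[senders]  # [num_edges, 2]
# Optional scalar edge feature: Euclidean distance, kept 2-D as [num_edges, 1].
distance = np.linalg.norm(displacement, axis=-1, keepdims=True)
print(distance[:, 0])  # [0. 5. 5. 0. 0. 1. 1. 0.]
```

In TensorFlow the scalar version would be tf.norm(edge_distances, axis=-1, keepdims=True).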
Joshuaalbert commented 4 years ago

Thanks! Closing now.

Joshuaalbert commented 4 years ago

BTW, for future searchers I'll leave this here:

import tensorflow as tf
from graph_nets import blocks, utils_tf
from graph_nets.graphs import GraphsTuple


def batched_tensor_to_fully_connected_graph_tuple_dynamic(nodes_tensor, pos=None, globals=None):
    """
    Convert a tensor with a batch dimension to a batch of fully connected graphs.
    :param nodes_tensor: [B, num_nodes, F] Tensor to turn into nodes. F must be statically known.
    :param pos: [B, num_nodes, D] Tensor used to compute edge features as position differences
        (receiver minus sender). D must be statically known.
    :param globals: [B, G] Tensor to use as globals. G must be statically known.
    :return: GraphsTuple with a batch of B fully connected graphs (self edges included).
    """
    shape = tf.shape(nodes_tensor)
    batch_size, num_nodes = shape[0], shape[1]
    F = nodes_tensor.shape.as_list()[-1]
    # One graph per batch element, each with num_nodes nodes and no edges yet.
    graphs_with_nodes = GraphsTuple(n_node=tf.fill([batch_size], num_nodes),
                                    n_edge=tf.fill([batch_size], 0),
                                    nodes=tf.reshape(nodes_tensor, [batch_size * num_nodes, F]),
                                    edges=None, globals=None, receivers=None, senders=None)
    # Batch size and node count may be dynamic, hence the _dynamic variant.
    graphs_tuple_with_nodes_connectivity = utils_tf.fully_connect_graph_dynamic(
        graphs_with_nodes, exclude_self_edges=False)

    if pos is not None:
        D = pos.shape.as_list()[-1]
        # Temporarily use positions as node features so the broadcasts gather positions.
        graphs_with_position = graphs_tuple_with_nodes_connectivity.replace(
            nodes=tf.reshape(pos, [batch_size * num_nodes, D]))
        edge_distances = (
                blocks.broadcast_receiver_nodes_to_edges(graphs_with_position) -
                blocks.broadcast_sender_nodes_to_edges(graphs_with_position))
        graphs_with_nodes_edges = graphs_tuple_with_nodes_connectivity.replace(edges=edge_distances)
    else:
        graphs_with_nodes_edges = utils_tf.set_zero_edge_features(
            graphs_tuple_with_nodes_connectivity, 1, dtype=nodes_tensor.dtype)

    if globals is not None:
        graphs_with_nodes_edges_globals = graphs_with_nodes_edges.replace(globals=globals)
    else:
        graphs_with_nodes_edges_globals = utils_tf.set_zero_global_features(
            graphs_with_nodes_edges, global_size=1, dtype=nodes_tensor.dtype)

    return graphs_with_nodes_edges_globals
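When trying the function above on small inputs, it can help to know which shapes to expect. With exclude_self_edges=False every graph gets num_nodes**2 edges, so the batch has batch_size * num_nodes**2 edges in total. The hypothetical helper below only writes down that bookkeeping, with edge and global sizes defaulting to 1 to match the zero-feature fallbacks; it takes plain Python ints, so its output can be compared against the shapes of the returned GraphsTuple evaluated in eager mode.

```python
def expected_graphs_tuple_shapes(batch_size, num_nodes, F, D=None, G=None):
    """Hypothetical helper: shapes produced by the function above.

    D is the position dimension (edge feature size) and G the global size; when
    `pos` or `globals` is not given, the function falls back to size-1 zero features.
    """
    num_edges = batch_size * num_nodes ** 2  # fully connected, self loops included
    return {
        "n_node": (batch_size,),
        "n_edge": (batch_size,),
        "nodes": (batch_size * num_nodes, F),
        "edges": (num_edges, D if D is not None else 1),
        "senders": (num_edges,),
        "receivers": (num_edges,),
        "globals": (batch_size, G if G is not None else 1),
    }


shapes = expected_graphs_tuple_shapes(batch_size=2, num_nodes=3, F=8, D=2)
print(shapes["edges"])  # (18, 2)
```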