google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

GraphTuple from batched tensors does not offset Senders/Receivers #135

Closed: Pol-Zeimet closed this issue 3 years ago

Pol-Zeimet commented 3 years ago

I am currently working on a project which shares some similarities with the code posted in this issue.

The idea is to build a GraphsTuple from batches of nodes, edges, senders and receivers. Connected nodes and edges are concatenated and, via dense layers, turned into new node embeddings. This all works as it should with a batch size of 1.

As soon as more graphs are involved, the GraphsTuple does not behave as I expected. Instead of aggregating nodes and edges within each individual graph, the aggregation interprets senders and receivers as if they only refer to the first graph in the batch. For example, if the senders and receivers of graphs[1] contain 0, that index addresses nodes[0] of the combined GraphsTuple, and therefore node[0] of graphs[0].

This absolutely makes sense, except that I expected the senders and receivers to be offset automatically based on the provided n_node and n_edge, e.g. the senders of graph[1] becoming senders + n_node[0].
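A minimal NumPy sketch of both behaviours, using two small made-up graphs (every value here is hypothetical). `np.take` stands in for the gather that happens when nodes are broadcast to edges: with per-graph (local) indices, graph 1's edges read graph 0's rows; adding the exclusive cumulative sum of `n_node` (for graph 1 that is just `n_node[0]`) points them at the right rows.

```python
import numpy as np

# Two small hypothetical graphs; node features are just labels to make rows visible.
n_node = np.array([3, 2])                            # graph 0: 3 nodes, graph 1: 2 nodes
nodes = np.array([[0.], [1.], [2.], [10.], [11.]])   # rows 0-2: graph 0, rows 3-4: graph 1
n_edge = np.array([2, 2])
senders_local = np.array([0, 1, 0, 1])               # local indices, graph 0's edges then graph 1's

# Without offsets, graph 1's edges read graph 0's nodes.
wrong = np.take(nodes, senders_local, axis=0)
print(wrong.ravel().tolist())   # → [0.0, 1.0, 0.0, 1.0]

# Offset each graph's indices by the number of nodes in all preceding graphs.
node_offsets = np.concatenate([[0], np.cumsum(n_node)[:-1]])   # [0, 3]
senders_global = senders_local + np.repeat(node_offsets, n_edge)
right = np.take(nodes, senders_global, axis=0)
print(right.ravel().tolist())   # → [0.0, 1.0, 10.0, 11.0]
```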

Or am I missing something obvious here?

Here is my code.

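import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from graph_nets import blocks, graphs

class graph_conv(keras.layers.Layer):
  def __init__(self, input_units, intermediate_units, node_shape, edge_shape, **kwargs):
      super(graph_conv, self).__init__(**kwargs)
      # stored so that get_config() can serialize them
      self.input_units = input_units
      self.intermediate_units = intermediate_units
      self.node_shape = node_shape
      self.edges_shape = edge_shape

      self.node_layer_in = layers.Dense(input_units, activation='relu', name='node_layer_in')
      self.intermediate_node_layer = layers.Dense(intermediate_units, activation='relu', name='intermediate_node_layer')
      self.node_layer_out = layers.Dense(self.node_shape, activation='relu', name='node_layer_out')
      # the attribute is self.edges_shape (self.edge_shape was never set)
      self.edge_layer = layers.Dense(self.edges_shape, activation='relu', name='edge_layer')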

  def call(self, nodes, edges, senders, receivers):

      nodes_shape = nodes.shape
      edges_shape = edges.shape
      assert self.node_shape == nodes.shape.as_list()[-1]
      assert self.edges_shape == edges.shape.as_list()[-1]

      batch_size = tf.shape(nodes)[0]
      num_nodes = tf.constant(nodes_shape[1])
      num_edges =tf.constant(edges_shape[1])

      #building_graph
      combined_graphs_tuple = graphs.GraphsTuple(n_node =tf.fill([batch_size], num_nodes),
                                      n_edge=tf.fill([batch_size], num_edges),
                                      nodes=tf.reshape(nodes, [batch_size * num_nodes, self.node_shape]),
                                      edges=tf.reshape(edges, [batch_size * num_edges, self.edges_shape]),
                                      senders=tf.reshape(senders, [batch_size * num_edges]),
                                      receivers=tf.reshape(receivers, [batch_size * num_edges]),
                                      globals=None,
                                      )

      #Step1
      #concatenate connected nodes with connecting edge
      left = blocks.broadcast_receiver_nodes_to_edges(combined_graphs_tuple)
      center = combined_graphs_tuple.edges 
      right = blocks.broadcast_sender_nodes_to_edges(combined_graphs_tuple)
      concatenated = tf.concat([left,center,right], axis = 1)

      combined_graphs_tuple = combined_graphs_tuple.replace(
          edges=concatenated
          )

      #Step2
      #calculate hidden states via mlp on each concatenation
      embeddings =  self.node_layer_in(combined_graphs_tuple.edges)
      combined_graphs_tuple = combined_graphs_tuple.replace(
          edges = self.node_layer_out(embeddings))

      #Step3
      #aggregate hidden states to create new node embeddings
      reducer = tf.math.unsorted_segment_mean
      combined_graphs_tuple = combined_graphs_tuple.replace(
        nodes= blocks.ReceivedEdgesToNodesAggregator(reducer=reducer)(combined_graphs_tuple))

      #Step4
      #calculate new value for edges from embedding
      combined_graphs_tuple = combined_graphs_tuple.replace(
          edges=self.edge_layer(combined_graphs_tuple.edges))

      #Step 5
      #break apart GraphsTuple into individual graphs

      nodes = tf.reshape(combined_graphs_tuple.nodes, [batch_size, num_nodes, self.node_shape])
      edges = tf.reshape(combined_graphs_tuple.edges, [batch_size, num_edges, self.edges_shape])
      senders =  tf.reshape(combined_graphs_tuple.senders, [batch_size, num_edges])
      receivers = tf.reshape(combined_graphs_tuple.receivers, [batch_size, num_edges])

      return nodes, edges, senders, receivers

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = super().get_config().copy()
    config.update({
      'input_units': self.input_units,
      'intermediate_units': self.intermediate_units,
      'node_shape': self.node_shape,
      'edge_shape': self.edges_shape,
    })
    return config
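To make Steps 1 and 3 concrete, here is a NumPy sketch of what they compute on one graph. It is an approximation under stated assumptions: the Dense layers of Steps 2 and 4 are left out, `np.add.at` plays the role of the scatter in `tf.math.unsorted_segment_mean`, and `message_passing_step` is a hypothetical helper name. Nodes that receive no edge end up with an all-zero row, matching how `tf.math.unsorted_segment_mean` outputs 0 for empty segments.

```python
import numpy as np

def message_passing_step(nodes, edges, senders, receivers):
    """Hypothetical mirror of Steps 1 and 3 for one graph (Dense layers omitted)."""
    # Step 1: concatenate [receiver node, edge, sender node] for every edge.
    messages = np.concatenate([nodes[receivers], edges, nodes[senders]], axis=1)
    # Step 3: mean of the incoming messages for each receiving node.
    num_nodes = nodes.shape[0]
    sums = np.zeros((num_nodes, messages.shape[1]))
    counts = np.zeros(num_nodes)
    np.add.at(sums, receivers, messages)
    np.add.at(counts, receivers, 1)
    # Nodes with no incoming edge keep a zero row (divide by 1 instead of 0).
    return sums / np.maximum(counts, 1)[:, None]

nodes = np.array([[1.], [2.], [3.]])
edges = np.array([[0.5], [0.5]])
senders = np.array([0, 1])
receivers = np.array([1, 1])
print(message_passing_step(nodes, edges, senders, receivers).tolist())
# → [[0.0, 0.0, 0.0], [2.0, 0.5, 1.5], [0.0, 0.0, 0.0]]
```

Note that the output width here is three times the input feature width, because the sketch skips `node_layer_in`/`node_layer_out`, which project the concatenation back to `node_shape` in the layer itself.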

Maybe I am building my Tuple in the wrong way?

If you want to test it, you can compare the outputs like this:

import numpy as np
import tensorflow as tf

#node and edge arrays are zero-padded to a fixed shape.
#senders and receivers have a fixed length and are padded with -1.
node_shape = some_value
edge_shape = some_value
input_units = some_value
intermediate_units = some_value

gc = graph_conv(input_units, intermediate_units, node_shape, edge_shape)

nodes, edges, senders, receivers = some_graph
nodes2, edges2, senders2, receivers2 = some_other_graph

#single graph: add a leading batch dimension of size 1
new_nodes, new_edges, new_senders, new_receivers = gc(tf.constant(np.asarray([nodes])),
                                                      tf.constant(np.asarray([edges])),
                                                      tf.constant(np.asarray([senders])),
                                                      tf.constant(np.asarray([receivers])))

#two graphs stacked along the batch dimension (built from the raw arrays, not the tensors above)
batched_nodes = tf.constant(np.asarray([nodes, nodes2]))
batched_edges = tf.constant(np.asarray([edges, edges2]))
batched_senders = tf.constant(np.asarray([senders, senders2]))
batched_receivers = tf.constant(np.asarray([receivers, receivers2]))

new_nodes_batched, new_edges_batched, new_senders_batched, new_receivers_batched = gc(batched_nodes,
                                                                                      batched_edges,
                                                                                      batched_senders,
                                                                                      batched_receivers)

Running it with my data, the first output (new_nodes, new_edges, new_senders, new_receivers) looks just as it should. The second one, however, does not: new_nodes_batched, for example, contains one array that looks fine, while the second one is filled with zeros.
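One plausible reading of this symptom, sketched in NumPy under the assumption that both graphs share the same fixed `num_nodes`/`num_edges` (the sizes and the `segment_mean` helper below are made up for illustration): without offsets every receiver index is smaller than `num_nodes`, so Step 3 never scatters anything into the second graph's rows, and a segment mean over an empty segment is 0.

```python
import numpy as np

def segment_mean(data, segment_ids, num_segments):
    """Hypothetical NumPy stand-in for tf.math.unsorted_segment_mean on 2-D data."""
    sums = np.zeros((num_segments, data.shape[1]))
    counts = np.zeros(num_segments)
    np.add.at(sums, segment_ids, data)
    np.add.at(counts, segment_ids, 1)
    return sums / np.maximum(counts, 1)[:, None]   # empty segments stay 0

batch_size, num_nodes, num_edges = 2, 3, 2
receivers_one_graph = np.array([1, 2])                  # local indices, same for both graphs
edge_messages = np.ones((batch_size * num_edges, 1))    # one feature per edge message

# Flattened batch WITHOUT offsets: every receiver points into graph 0's rows.
receivers_flat = np.tile(receivers_one_graph, batch_size)   # [1, 2, 1, 2]
new_nodes = segment_mean(edge_messages, receivers_flat, batch_size * num_nodes)
print(new_nodes.reshape(batch_size, num_nodes).tolist())
# → [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]]  (the second graph receives nothing)

# WITH offsets, each graph aggregates into its own rows.
receivers_offset = receivers_flat + np.repeat(np.arange(batch_size) * num_nodes, num_edges)
fixed_nodes = segment_mean(edge_messages, receivers_offset, batch_size * num_nodes)
print(fixed_nodes.reshape(batch_size, num_nodes).tolist())
# → [[0.0, 1.0, 1.0], [0.0, 1.0, 1.0]]
```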

Sorry for the wall of text and thank you in advance :)

alvarosg commented 3 years ago

In principle when you build GraphsTuples on your own, you are in charge of adding the offsets yourself.

We are working on a beta version of a method that correctly batches a `tf.data.Dataset` of GraphsTuples by adding the offsets, following logic similar to `utils_tf.concat(..., axis=0)`, but it is not yet ready to be open sourced.

In the meantime, you may want to use the hidden method `utils_tf._compute_stacked_offsets` and add its output to the senders and receivers before putting them into the GraphsTuple.
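The offset rule can be sketched in NumPy for graphs of different sizes: take the exclusive cumulative sum of `n_node` and repeat each graph's value once per edge of that graph. This follows the same idea as `utils_tf._compute_stacked_offsets`, but it is not that function's code, and `stacked_offsets` is a hypothetical name; since the graph_nets method is private, it is worth checking its signature in the installed version before relying on it.

```python
import numpy as np

def stacked_offsets(n_node, n_edge):
    """Hypothetical helper: per-edge offset = number of nodes in all preceding graphs."""
    n_node = np.asarray(n_node)
    exclusive_cumsum = np.concatenate([[0], np.cumsum(n_node)[:-1]])
    return np.repeat(exclusive_cumsum, n_edge)

n_node = [3, 1, 2]                         # three graphs of different sizes
n_edge = [2, 3, 1]
senders = np.array([0, 2, 0, 0, 0, 1])     # local indices, graph by graph
offsets = stacked_offsets(n_node, n_edge)
print(offsets.tolist())                    # → [0, 0, 3, 3, 3, 4]
print((senders + offsets).tolist())        # → [0, 2, 3, 3, 3, 5]
```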

alvarosg commented 3 years ago

Actually, since the number of edges and nodes seems to be known and fixed statically, you can just add:

num_previous_accumulated_nodes_per_graph = tf.range(0, batch_size) * num_nodes
offsets = tf.tile(num_previous_accumulated_nodes_per_graph[:, None], [1, num_edges])
offsets = tf.reshape(offsets, [-1])

to the senders and receivers (where num_nodes and num_edges are the number of nodes and the number of edges per graph).
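A quick NumPy check of the snippet above with small, made-up sizes (`np.arange`/`np.tile`/`np.reshape` stand in for the TensorFlow ops). Row-major flattening of a `[batch_size, num_edges]` senders array lists graph 0's edges first, then graph 1's, and so on, which is the same order the tiled offsets follow.

```python
import numpy as np

batch_size, num_nodes, num_edges = 3, 4, 2   # hypothetical fixed sizes

# Same three steps as the TensorFlow snippet, written with NumPy.
num_previous_accumulated_nodes_per_graph = np.arange(0, batch_size) * num_nodes   # [0, 4, 8]
offsets = np.tile(num_previous_accumulated_nodes_per_graph[:, None], [1, num_edges])  # shape [3, 2]
offsets = np.reshape(offsets, [-1])
print(offsets.tolist())   # → [0, 0, 4, 4, 8, 8]

# Senders of shape [batch_size, num_edges], flattened graph by graph and offset.
senders = np.array([[0, 1], [2, 3], [1, 0]])
global_senders = senders.reshape(-1) + offsets
print(global_senders.tolist())   # → [0, 1, 6, 7, 9, 8]
```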

Pol-Zeimet commented 3 years ago

That explains the behaviour of my code :) Thank you for also posting a possible solution for the offset 🥇 I might have to adjust it to account for the -1 padding when applying it to my senders/receivers. But you really saved me a lot of time!

Quick edit for final code. Now working as expected.

      num_previous_accumulated_nodes_per_graph = tf.range(0, batch_size) * num_nodes
      offsets = tf.tile(num_previous_accumulated_nodes_per_graph[:, None], [1, num_edges])
      offsets = tf.reshape(offsets, [-1])

      senders = tf.reshape(senders, [batch_size * num_edges])
      receivers = tf.reshape(receivers, [batch_size * num_edges])
      # tf.range yields int32; cast in case senders/receivers are int64 (e.g. built from NumPy)
      offsets = tf.cast(offsets, senders.dtype)

      # only offset real indices; leave the -1 padding untouched
      senders_offset = tf.where(senders != -1, senders + offsets, senders)
      receivers_offset = tf.where(receivers != -1, receivers + offsets, receivers)

      combined_graphs_tuple = graphs.GraphsTuple(n_node =tf.fill([batch_size], num_nodes),
                                      n_edge=tf.fill([batch_size], num_edges),
                                      nodes=tf.reshape(nodes, [batch_size * num_nodes, self.node_shape]),
                                      edges=tf.reshape(edges, [batch_size * num_edges, self.edges_shape]),
                                      senders=senders_offset,
                                      receivers=receivers_offset,
                                      globals=None,
                                      )
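The masking in this final code can be checked with a NumPy sketch (sizes and indices are made up; `np.repeat` produces the same offsets as the tile-and-reshape version). One assumption worth flagging: if Step 5 of the layer still reshapes `combined_graphs_tuple.senders` and `receivers` for its return values, those now carry the offsets, so feeding them into a second `graph_conv` would offset them twice. Subtracting the offsets with the same mask restores the per-graph indices.

```python
import numpy as np

batch_size, num_nodes, num_edges = 2, 3, 3   # hypothetical fixed sizes
# Local indices padded with -1, shape [batch_size, num_edges].
senders = np.array([[0, 2, -1], [1, -1, -1]])

offsets = np.repeat(np.arange(batch_size) * num_nodes, num_edges)   # [0, 0, 0, 3, 3, 3]
flat = senders.reshape(-1)
# Same rule as tf.where(senders != -1, senders + offsets, senders).
senders_offset = np.where(flat != -1, flat + offsets, flat)
print(senders_offset.tolist())   # → [0, 2, -1, 4, -1, -1]

# Assumption: the layer returns offset indices, so undo them (keeping -1) before reshaping back.
senders_local = np.where(senders_offset != -1, senders_offset - offsets, senders_offset)
print(senders_local.reshape(batch_size, num_edges).tolist())   # → [[0, 2, -1], [1, -1, -1]]
```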