google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

Training on batches of GraphsTuples? #147

Open robertswil opened 2 years ago

robertswil commented 2 years ago

Let's say I want to train an LSTM or transformer on sequences of graphs using Sonnet2/TF2:

I want to represent the graphs in each sequence as one GraphsTuple, which means each batch is essentially an iterable of GraphsTuples, each with a variable number of graphs. This works great until it's time to get the input signature and compile the update step: it's unclear to me how to define the TensorSpec for this type of input. Is my best route to subclass collections.namedtuple(), similar to how you define a GraphsTuple, or can you suggest a more elegant solution?
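[Editor's note: recent graph_nets releases ship utils_tf.specs_from_graphs_tuple, which builds the nested tf.TensorSpec signature for a GraphsTuple; prefer that if your version has it. The snippet below is only a dependency-free sketch of the same idea: the Tensor placeholder and the (shape, dtype) pairs are illustrative stand-ins, not real TensorFlow objects.]

```python
import collections

# Stand-in for graph_nets.graphs.GraphsTuple; field names match the real
# class, but the real one also carries replace/map helper methods.
GraphsTuple = collections.namedtuple(
    "GraphsTuple",
    ["nodes", "edges", "receivers", "senders", "globals", "n_node", "n_edge"])

# Hypothetical placeholder for a tensor: just a shape tuple and a dtype name.
Tensor = collections.namedtuple("Tensor", ["shape", "dtype"])

def specs_from_graphs_tuple(gt):
    """Sketch of what utils_tf.specs_from_graphs_tuple does: build one spec
    per field, relaxing the leading (node/edge/graph count) dimension to
    None so batches with variable numbers of graphs all match the same
    signature. In real TF2 code each spec would be a tf.TensorSpec."""
    def to_spec(t):
        return Tensor((None,) + tuple(t.shape[1:]), t.dtype)
    return GraphsTuple(*[to_spec(t) for t in gt])

example = GraphsTuple(
    nodes=Tensor((10, 4), "float32"),
    edges=Tensor((20, 2), "float32"),
    receivers=Tensor((20,), "int32"),
    senders=Tensor((20,), "int32"),
    globals=Tensor((3, 5), "float32"),
    n_node=Tensor((3,), "int32"),
    n_edge=Tensor((3,), "int32"))

spec = specs_from_graphs_tuple(example)
print(spec.nodes)   # Tensor(shape=(None, 4), dtype='float32')
```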

Thanks

alvarosg commented 2 years ago

Thanks for your message!

One option that will hopefully work for you is to keep each sequence in a single GraphsTuple, give the node/edge features a time axis, and slice out one step at a time inside the update function:

def update_fn(input_graph_sequence, ...):

  def loop_body(state, step_i):
     # Node/edge features carry a time axis: [num_items, num_steps, ...].
     graph_step_i = input_graph_sequence.replace(
        nodes=input_graph_sequence.nodes[:, step_i],
        edges=input_graph_sequence.edges[:, step_i])
     ...

  num_steps = input_graph_sequence.nodes.shape.as_list()[1]
  tf.scan(loop_body, tf.range(num_steps), ...)
  ...
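[Editor's note: the tf.scan sketch above relies on GraphsTuple.replace and on features shaped [num_items, num_steps, ...]. Below is a dependency-free sketch of the same per-step slicing, using a minimal namedtuple stand-in for the real graphs.GraphsTuple; the names unroll_steps and the toy shapes are illustrative, not graph_nets APIs.]

```python
import collections

# Minimal stand-in for graphs.GraphsTuple: the real class subclasses a
# namedtuple and exposes `replace` (a public wrapper around _replace).
class GraphsTuple(collections.namedtuple("GraphsTuple", ["nodes", "edges"])):
    def replace(self, **kwargs):
        return self._replace(**kwargs)

def unroll_steps(seq_graphs):
    # nodes/edges hold one list of per-step features per node/edge, i.e.
    # shape [num_items][num_steps]; slice out one step at a time, just as
    # the tf.scan loop body does with nodes[:, step_i].
    num_steps = len(seq_graphs.nodes[0])
    return [
        seq_graphs.replace(
            nodes=[n[step_i] for n in seq_graphs.nodes],
            edges=[e[step_i] for e in seq_graphs.edges])
        for step_i in range(num_steps)
    ]

# Two nodes and one edge, each with features for two time steps.
seq = GraphsTuple(nodes=[[0, 1], [2, 3]], edges=[[4, 5]])
steps = unroll_steps(seq)
# steps[0].nodes == [0, 2], steps[1].nodes == [1, 3]
```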

Hope this helps!

robertswil commented 2 years ago

This worked. Thanks @alvarosg !

robertswil commented 2 years ago

Follow-on issue:

I am passing batches to the model during training like so:

outputs = tf.convert_to_tensor([model(graphs_tuple) for graphs_tuple in inputs])

As a reminder, each batch is an iterable of GraphsTuples, and each GraphsTuple represents a sequence of graphs for one training data point.

The GraphIndependent object of the encoder block (EncodeProcessDecode) throws the error: AttributeError: 'GraphsTuple' object has no attribute 'replace'. According to the stack trace, the error originates in modules.GraphIndependent._build().

Any ideas on how to solve?

alvarosg commented 2 years ago

Could you check the type of the object being passed to the model?

My guess is that the GraphsTuple input you are passing is not actually a graphs.GraphsTuple: some serialization library or similar has transformed it into a namedtuple that looks the same but is not actually the same class, so it does not have the extra methods.

A simple fix to get the right type is to do:

graphs_tuple = graphs.GraphsTuple(*graphs_tuple)

before passing it to EncodeProcessDecode, but it may be good to understand where the type gets messed up.
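[Editor's note: here is a minimal, self-contained reproduction of this class-identity failure. The GraphsTuple and LookAlike classes below are stand-ins for the real graph_nets ones: a serializer can hand back a namedtuple with the same type name and fields but without the methods the real graphs.GraphsTuple adds.]

```python
import collections

class GraphsTuple(collections.namedtuple("GraphsTuple", ["nodes", "edges"])):
    """Stand-in for graphs.GraphsTuple: the real class subclasses a
    namedtuple and adds a public `replace` method."""
    def replace(self, **kwargs):
        return self._replace(**kwargs)

# What a serialization round-trip can produce: same type name, same
# fields, same repr, but a plain namedtuple without the extra methods.
LookAlike = collections.namedtuple("GraphsTuple", ["nodes", "edges"])

fake = LookAlike(nodes=[1], edges=[2])
assert not hasattr(fake, "replace")   # hence the AttributeError

# The fix: rebuild with the real class before calling the model.
fixed = GraphsTuple(*fake)
assert hasattr(fixed, "replace")
```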

robertswil commented 2 years ago

Your hunch was correct! It was being transformed into a plain collections namedtuple during the list comprehension:

outputs = tf.convert_to_tensor([model(inputs) for inputs in inputs_train])

Is that^^ the preferred way to feed a batch of graph sequences during the update step? Wondering if it isn't, since when I run graphs.GraphsTuple(*inputs) during training, I receive the following error:

OperatorNotAllowedInGraphError: iterating over tf.Tensor is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

Checking a bit more, type(inputs) when this error is thrown is a symbolic tensor, not a GraphsTuple, which I guess means this happens while the function is being traced, before the backend actually runs the first batch through the model.
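[Editor's note: a sketch of the takeaway, under the assumption that the rebuild should happen in eager Python, before the batch crosses into any @tf.function. Inside a traced function the arguments are symbolic tensors, and unpacking with * requires iteration, which tf.Tensor forbids in graph mode. GraphsTuple and rebuild_eagerly below are stand-ins, not graph_nets APIs.]

```python
import collections

# Stand-in for graphs.GraphsTuple (the real class adds helper methods).
GraphsTuple = collections.namedtuple("GraphsTuple", ["nodes", "edges"])

def rebuild_eagerly(batch):
    # Do this in the input pipeline / eager training loop. Inside a
    # @tf.function, GraphsTuple(*x) would try to iterate a symbolic
    # tf.Tensor and raise OperatorNotAllowedInGraphError.
    return [GraphsTuple(*x) for x in batch]

batch = [("n0", "e0"), ("n1", "e1")]   # e.g. deserialized look-alike tuples
rebuilt = rebuild_eagerly(batch)
```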