google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

How to save trained models? #83

Closed: maguileracanon closed this issue 4 years ago

maguileracanon commented 5 years ago

I would like to save my trained models for future use, especially as I am considering building my PhD predictors on this library. Is there any way to add a .save method to the models in demos/models.py to achieve something similar to what is explained here?
Many thanks

alvarosg commented 5 years ago

Because graph_nets is not built on top of Keras, saving and restoring models works slightly differently. To save the model you need to use a tf.train.Saver and call saver.save. To restore it, you should build the TensorFlow graph in exactly the same way (so the variables get the same names, which is how the Saver matches them), and then call saver.restore. See the example below:

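```python
# TensorFlow 1.x / Sonnet 1 API (tf.train.Saver and tf.Session are TF1 APIs).
import numpy as np
import sonnet as snt
import tensorflow as tf
from graph_nets import modules
from graph_nets import utils_tf


def get_input_graphs():
  # Some function that returns a graphs.GraphsTuple.
  # Placeholder: a single toy graph with 3 nodes and 2 edges; replace with your data.
  data_dict = {
      "globals": np.ones([2], dtype=np.float32),
      "nodes": np.ones([3, 2], dtype=np.float32),
      "edges": np.ones([2, 2], dtype=np.float32),
      "senders": np.array([0, 1], dtype=np.int32),
      "receivers": np.array([1, 2], dtype=np.int32),
  }
  return utils_tf.data_dicts_to_graphs_tuple([data_dict])


def build_and_connect_model(input_graphs):
  graph_network = modules.GraphNetwork(
      edge_model_fn=lambda: snt.Linear(output_size=4),
      node_model_fn=lambda: snt.Linear(output_size=4),
      global_model_fn=lambda: snt.Linear(output_size=4))
  output_graphs = graph_network(input_graphs)
  return graph_network, output_graphs


def log_variables(sess, variables):
  # Print the name and first value of each variable, so the values printed
  # after restoring can be compared with the ones that were saved.
  vars_out = sess.run(variables)
  print([(var.name, var_out.flatten()[0])
         for var, var_out in zip(variables, vars_out)])


# Saving it.
tf.reset_default_graph()
input_graphs = get_input_graphs()
graph_net, output_graphs = build_and_connect_model(input_graphs)
initializer = tf.global_variables_initializer()

saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(initializer)
  saver.save(sess, "/tmp/model")
  log_variables(sess, graph_net.variables)

# Reloading it later: rebuild the same graph, then restore the saved values
# instead of running the initializer.
tf.reset_default_graph()
input_graphs = get_input_graphs()
graph_net, output_graphs = build_and_connect_model(input_graphs)
saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, "/tmp/model")
  log_variables(sess, graph_net.variables)
```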

Hope this helps!

maguileracanon commented 5 years ago

It does, thanks!

maguileracanon commented 4 years ago

I understand that computation graphs work differently in TensorFlow 2. Is tf.train.Saver() still valid in TensorFlow 2?

alvarosg commented 4 years ago

Here are some examples of checkpointing and model storage using Sonnet 2 and TF2: https://github.com/deepmind/sonnet#tensorflow-checkpointing

maguileracanon commented 4 years ago

Thank you!