google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

GraphNetwork model checkpointing method? #123

Closed by agopal42 4 years ago

agopal42 commented 4 years ago

How do I checkpoint a GraphNetwork model? I tried using tf.train.Checkpoint, but the GraphNetwork class doesn't seem to be a trackable object. tf.train.Checkpoint throws the following exception:

ValueError: Checkpoint was expecting a trackable object (an object derived from TrackableBase), got <attention.GNGlobal object at 0x7f206c71a438>. If you believe this object should be trackable (i.e. it is part of the TensorFlow Python API and manages state), please open an issue.

Thanks!
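The contents of the linked issue are not reproduced here. The error itself says that tf.train.Checkpoint only accepts trackable objects, and the class in the traceback, attention.GNGlobal, appears to be a user-defined wrapper. The following is a minimal sketch of one general way to make such a wrapper trackable, assuming TensorFlow 2 with eager execution: subclass tf.Module, so that variables (and submodules) assigned as attributes are discovered automatically. The GNGlobal class below is a hypothetical stand-in with two plain variables, not the graph_nets GraphNetwork API. In Sonnet 2, snt.Module itself derives from tf.Module, so Sonnet 2 modules stored as attributes of such a wrapper would be tracked in the same way.

```python
import os
import tempfile

import tensorflow as tf


class GNGlobal(tf.Module):
    """Hypothetical stand-in for the user's attention.GNGlobal class.

    Subclassing tf.Module makes instances trackable: tf.Variables (and
    other tf.Module objects) assigned as attributes are found by
    tf.train.Checkpoint without any extra registration.
    """

    def __init__(self, size, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.ones([size, size]), name="w")
        self.b = tf.Variable(tf.zeros([size]), name="b")

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


def save_and_restore():
    """Saves one GNGlobal, restores it into a fresh copy, returns b."""
    with tempfile.TemporaryDirectory() as directory:
        model = GNGlobal(size=3)
        # Change a variable so the restored value is distinguishable
        # from the freshly initialized one.
        model.b.assign([1.0, 2.0, 3.0])
        path = tf.train.Checkpoint(model=model).save(
            os.path.join(directory, "ckpt"))

        # Restore into a new object with the same attribute structure;
        # variables are matched by their attribute path ("model/b").
        restored = GNGlobal(size=3)
        status = tf.train.Checkpoint(model=restored).restore(path)
        status.assert_existing_objects_matched()
        return restored.b.numpy()
```

Calling save_and_restore() should return the assigned values [1.0, 2.0, 3.0] rather than zeros, since restoration matches variables by their position in the object graph rather than by variable name.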

alvarosg commented 4 years ago

See #83

agopal42 commented 4 years ago

Thanks! :)