eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

Problem saving a graph with the Saver class #31

Closed: lucataglia closed this issue 6 years ago

lucataglia commented 6 years ago

I'm looking for a way to save a TF model and then restore it for further training, just as the Python API lets me do. In Python, I save the model like this:

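# Python (TensorFlow 1.x)
import tensorflow as tf

saver = tf.train.Saver()  # saves the graph's variables by default
saver.save(sess, "./model-saver/foo")  # sess: an active tf.Session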

I need to do something similar in Scala. The Saver class documentation in the Scala API says that I can do:

// From the Scala API documentation:
// Using a slight abuse of notation for paths:
saver.save(session, "my-model", globalStep = 0) ==> filename: "my-model-0"
saver.save(session, "my-model", globalStep = 1000) ==> filename: "my-model-1000"
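The naming rule in the quoted documentation can be sketched in plain Scala. `checkpointPrefix` is a hypothetical helper, not part of the tensorflow_scala API; it only mimics the documented convention that a global step, when given, is appended to the base name after a hyphen. The no-step case (name left unchanged) is an assumption based on TensorFlow's usual behavior.

```scala
object CheckpointNaming {
  // Hypothetical helper (not tensorflow_scala code): mimics the documented rule
  // that saving with a global step appends "-<step>" to the base name.
  // Assumption: without a global step, the base name is used unchanged.
  def checkpointPrefix(base: String, globalStep: Option[Int]): String =
    globalStep match {
      case Some(step) => s"$base-$step"
      case None       => base
    }

  def main(args: Array[String]): Unit = {
    println(checkpointPrefix("my-model", Some(0)))    // my-model-0
    println(checkpointPrefix("my-model", Some(1000))) // my-model-1000
    println(checkpointPrefix("my-model", None))       // my-model
  }
}
```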

But I can't figure out how to create the saver object on which the save method is called. I tried the constructor of the Saver class, but it's private. Looking at the code on GitHub, I saw that there is a companion object that exposes some methods, but I didn't find a way to do what I need.

eaplatanios commented 6 years ago

@lucaRadicalbit You can use tf.saver(...), which takes exactly the same arguments as the constructor of the Saver class. :)

lucataglia commented 6 years ago

Maybe I'm missing something really simple, but I already tried tf.saver and got a compile error (screenshot: tf-saver) that mentioned a Saveable. I really can't figure out what I'm doing wrong.

eaplatanios commented 6 years ago

@lucaRadicalbit That's because the saver constructor does not take a session argument. This is the signature:

def saver(
        saveables: Set[Saveable] = null, reshape: Boolean = false, sharded: Boolean = false, maxToKeep: Int = 5,
        keepCheckpointEveryNHours: Float = 10000.0f, restoreSequentially: Boolean = false, filename: String = "model",
        builder: SaverDefBuilder = DefaultSaverDefBuilder, allowEmpty: Boolean = false,
        writerVersion: WriterVersion = V2, saveRelativePaths: Boolean = false, padGlobalStep: Boolean = false,
        name: String = "Saver"): Saver

So, if you want it to save all variables (the default behavior), all you need to do is:

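import java.nio.file.Paths
import org.platanios.tensorflow.api._

val savePath = Paths.get("./my-model")
// Create the saver after the model's variables have been defined.
val saver = tf.saver()
// After having created the saver, you can save using (`session` is your active Session):
saver.save(session, savePath)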

And you can reuse that saver object across different sessions too. :)
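When the same saver is reused for many saves, the maxToKeep parameter from the signature matters. In TensorFlow's Python Saver, max_to_keep bounds how many recent checkpoints stay on disk, deleting older ones as new ones are written; the sketch below assumes tf.saver behaves the same way. It is a plain-Scala model of that rule only: `afterSave` and `Retention` are hypothetical names, not tensorflow_scala code, and keepCheckpointEveryNHours is ignored.

```scala
object RetentionSketch {
  // Hypothetical model (not tensorflow_scala code) of a maxToKeep policy:
  // after each save, only the `maxToKeep` most recent checkpoints survive.
  final case class Retention(kept: Vector[String], deleted: Vector[String])

  def afterSave(previous: Vector[String], saved: String, maxToKeep: Int): Retention = {
    require(maxToKeep > 0, "this sketch only models a positive maxToKeep")
    val all = previous :+ saved // oldest first, newest last
    val excess = math.max(0, all.size - maxToKeep)
    Retention(kept = all.drop(excess), deleted = all.take(excess))
  }

  def main(args: Array[String]): Unit = {
    // Simulate saving at steps 0, 1000, ..., 6000 with the default maxToKeep = 5.
    val finalState = (0 to 6000 by 1000).foldLeft(Retention(Vector.empty, Vector.empty)) {
      (state, step) =>
        val r = afterSave(state.kept, s"my-model-$step", maxToKeep = 5)
        Retention(r.kept, state.deleted ++ r.deleted)
    }
    println(finalState.kept)    // the five newest prefixes: my-model-2000 to my-model-6000
    println(finalState.deleted) // my-model-0 and my-model-1000
  }
}
```

With seven saves and the default of five, the two oldest checkpoints are the ones dropped.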

lucataglia commented 6 years ago

@eaplatanios Thank you very much, it works!!

eaplatanios commented 6 years ago

@lucaRadicalbit That's great to hear! :)

eaplatanios commented 6 years ago

@sujitbiswas If you use an Estimator, you can add a CheckpointSaverHook, which will automatically handle saving and restoring state for you during training. :)

On Oct 24, 2017, 10:26 AM -0400, sujitbiswas wrote:

@eaplatanios Related question: how do I use the saver with an Estimator?