eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

No way to convert variables to saveables for tf.saver #154

Closed: Spiess closed this issue 5 years ago.

Spiess commented 5 years ago

To specify which variables tf.saver should save, they must be wrapped as Saveable objects. By default, tf.saver() wraps every global variable in the graph as a VariableSaveable (a class that is, unfortunately, private to the ops package). However, I have not found a way to convert a specific set of variables to saveables, which I need in order to selectively load variable values from different checkpoints.
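A minimal sketch of the behaviour described above. It uses hypothetical stand-in types (Saveable, Variable, VariableSaveable, MockSaver), not the real tensorflow_scala classes. When no saveables are passed, the mock saver wraps every global variable; passing an explicit set restricts what it tracks, which is what selective restoring needs.

```scala
// Hypothetical stand-ins for illustration only; NOT the tensorflow_scala API.
trait Saveable { def name: String }
final case class Variable(name: String)
final case class VariableSaveable(variable: Variable) extends Saveable {
  def name: String = variable.name
}

// Hypothetical saver: with no explicit saveables it wraps every global variable,
// otherwise it only tracks the set it was given.
final class MockSaver(globals: Set[Variable], saveables: Option[Set[Saveable]] = None) {
  val selected: Set[Saveable] =
    saveables.getOrElse(globals.map(v => (VariableSaveable(v): Saveable)))
  def names: List[String] = selected.toList.map(_.name).sorted
}

object SelectiveSaveDemo {
  val w: Variable = Variable("w")
  val b: Variable = Variable("b")

  // Default: every global variable is wrapped and tracked.
  val allByDefault: List[String] = new MockSaver(Set(w, b)).names

  // Selective: only the explicitly provided saveable is tracked.
  val onlyW: List[String] =
    new MockSaver(Set(w, b), Some(Set[Saveable](VariableSaveable(w)))).names

  def main(args: Array[String]): Unit = {
    println(allByDefault)
    println(onlyW)
  }
}
```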

eaplatanios commented 5 years ago

What if you try to cast each variable to a Saveable? For example: vars.map(_: Saveable), where vars is a set of variables. Does an implicit conversion happen if you do that?

Spiess commented 5 years ago

Unfortunately, vars.map(_: Saveable) doesn't work. I've also tried variables.map(variable => variable.castTo[Saveable]), which fails with "no implicits found for parameter evidence", as well as val variables: Set[Saveable] = Set(variable) and casting each variable individually.
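One Scala detail is worth noting here, independent of the library. When an underscore section such as `_: T` is the entire argument of a call, it expands to a function over the whole call: `f(_: T)` means `(x: T) => f(x)`. So vars.map(_: Saveable) is not a per-element ascription. The per-element form is a lambda whose body carries the ascription, e.g. `vars.map(v => (v: Saveable))`. The sketch below shows both rules on plain Scala values; `describe` is a hypothetical helper, and numeric widening to Long stands in for the implicit conversion to Saveable.

```scala
object PlaceholderDemo {
  // Hypothetical helper used only to show how placeholder syntax expands.
  def describe(n: Int): String = s"n=$n"

  // `describe(_: Int)` does not call describe; the underscore section expands
  // to a function over the whole call: (x: Int) => describe(x).
  val viaPlaceholder: Int => String = describe(_: Int)

  // An ascription inside the lambda body gives each element an expected type,
  // so the compiler applies a conversion per element (here, Int widened to Long).
  val widened: List[Long] = List(1, 2).map(n => (n: Long))

  def main(args: Array[String]): Unit = {
    println(viaPlaceholder(3))
    println(widened)
  }
}
```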

eaplatanios commented 5 years ago

I see. The conversion is an implicit class, but it was private. I removed the private modifier on master. Sorry about that! :)
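A sketch of why making the implicit class public fixes the problem, using hypothetical stand-in types rather than the actual tensorflow_scala classes. An implicit class also defines an implicit conversion from its constructor argument, so once it is visible (here via `import ops._`), ascribing a Variable to Saveable inside a lambda converts each element. If the class were private to `ops`, code outside `ops` could not see the conversion and the ascription would fail to compile, which matches the symptoms in this thread.

```scala
// Hypothetical stand-ins; NOT the tensorflow_scala types.
trait Saveable { def name: String }
final class Variable(val name: String)

object ops {
  // A public implicit class doubles as an implicit conversion Variable => VariableSaveable.
  // A `private` modifier would hide that conversion from code outside `ops`.
  implicit class VariableSaveable(val variable: Variable) extends Saveable {
    def name: String = variable.name
  }
}

object ConversionDemo {
  import ops._

  val vars: Set[Variable] = Set(new Variable("w"), new Variable("b"))

  // The ascription gives each element the expected type Saveable, so the
  // compiler inserts the implicit conversion VariableSaveable(v).
  val saveables: Set[Saveable] = vars.map(v => (v: Saveable))

  def main(args: Array[String]): Unit =
    println(saveables.map(_.name).toList.sorted.mkString(","))
}
```

The resulting Set[Saveable] is the kind of explicit subset that would then be handed to the saver in place of the default of all global variables.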