eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

Restore Model: IllegalArgumentException: Cannot find graph collection key named 'variables'. #38

Closed lucataglia closed 6 years ago

lucataglia commented 6 years ago

I'm trying to restore a model inside a Scala environment and I get this error:

Exception in thread "main" java.lang.IllegalArgumentException: Cannot find graph collection key named 'variables'.
    at org.platanios.tensorflow.api.core.Graph$Keys$$anonfun$fromName$1.apply(Graph.scala:1083)
    at org.platanios.tensorflow.api.core.Graph$Keys$$anonfun$fromName$1.apply(Graph.scala:1083)
    at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
    at scala.collection.AbstractMap.getOrElse(Map.scala:59)
    at org.platanios.tensorflow.api.core.Graph$Keys$.fromName(Graph.scala:1083)
    at org.platanios.tensorflow.api.core.Graph$$anonfun$importMetaGraphDef$2.apply(Graph.scala:617)
    at org.platanios.tensorflow.api.core.Graph$$anonfun$importMetaGraphDef$2.apply(Graph.scala:614)
    at scala.collection.Iterator$class.foreach(Iterator.scala:891)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
    at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
    at scala.collection.AbstractIterable.foreach(Iterable.scala:54)
    at org.platanios.tensorflow.api.core.Graph.importMetaGraphDef(Graph.scala:614)
    at org.platanios.tensorflow.api.ops.variables.Saver$.fromMetaGraphDef(Saver.scala:487)
    at linearRegression.LinearRegression_restoreModel4Training$.main(LinearRegression_restoreModel4Training.scala:27)
    at linearRegression.LinearRegression_restoreModel4Training.main(LinearRegression_restoreModel4Training.scala)

I am wondering: why do you look up the key 'variables' and then check whether the object is an UNBOUND_INPUTS key? I'm trying to understand whether I misunderstood how to save a model so that it can be restored and trained further.

[Two screenshots attached: 2017-11-03 15:41:08 and 2017-11-03 15:40:55]

Anyway, do you have any idea what causes this error? Here is the segment of my protobuf that contains the 'variables' collection:

collection_def {
  key: "variables"
  value {
    bytes_list {
      value: "\n\tweights:0\022\016weights/Assign\032\016weights/read:0"
      value: "\n\021weights/Adagrad:0\022\026weights/Adagrad/Assign\032\026weights/Adagrad/read:0"
    }
  }
}

I got that using my IDE's debugger.
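For readers wondering what the escaped value strings in the collection_def above contain: each one appears to be a serialized protobuf message made of length-delimited string fields. The sketch below decodes the first value by hand in plain Python. It assumes single-byte tags and lengths (true for this short message, not for protobuf in general), the function name is made up for illustration, and the field meanings in the comments are an assumption based on TensorFlow's VariableDef message.

```python
def decode_string_fields(data: bytes) -> dict:
    """Decode a protobuf message that contains only short string fields.

    Assumes every tag and every length fits in one byte (< 128), which holds
    for the short message below but not for protobuf in general.
    """
    fields = {}
    i = 0
    while i < len(data):
        tag = data[i]
        field_number, wire_type = tag >> 3, tag & 0x07
        if wire_type != 2:  # 2 = length-delimited (strings, bytes, messages)
            raise ValueError(f"unsupported wire type {wire_type}")
        length = data[i + 1]
        start = i + 2
        fields[field_number] = data[start:start + length].decode("utf-8")
        i = start + length
    return fields


# First `value` entry from the collection_def above; \022 etc. are octal escapes.
raw = b"\n\tweights:0\022\016weights/Assign\032\016weights/read:0"

# Field numbers, assumed to follow VariableDef:
# 1 = variable name, 2 = initializer name, 3 = snapshot name.
print(decode_string_fields(raw))
```

Under those assumptions the first entry decodes to the variable name weights:0, the initializer op weights/Assign, and the snapshot tensor weights/read:0, so the 'variables' collection is present in the file and the failure is in looking up the key, not in the saved data.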

lucataglia commented 6 years ago

@eaplatanios Just a question: is the filename.meta protobuf generated by saving a TF model with your Scala API different from the one generated by saving an identical TF model written with the Python API, or will the two files be identical if the models are exactly the same?

eaplatanios commented 6 years ago

@lucaRadicalbit That is unexpected. I can't seem to reproduce the error. The UNBOUND_INPUTS condition has to do with how to load unbound inputs from loop contexts and so it doesn't affect your case. It seems like the Graph.Keys object is not initialized correctly. What is the code you use to get that error?

eaplatanios commented 6 years ago

As for the filename.meta protobuf, there can be differences with the Python version as I sometimes name things differently, etc., but other than names, basic structure should be very similar and the Python API should be able to load models saved with my Scala API, and vice versa.

lucataglia commented 6 years ago

@eaplatanios Today I tried both. I failed to load in Scala a model generated with Python, and vice versa. If you try to load, for example, a Python model with your Scala API, don't you get any problems or errors?

eaplatanios commented 6 years ago

@lucaRadicalbit That's interesting. I want to figure this out, but in order to save some time, could you please send me the code you used (both Python and Scala), so I can reproduce the problem?

lucataglia commented 6 years ago

@eaplatanios I created a gist to share the four files. At the end of the two Python scripts there is also a debug print that you might want to comment out. When I load in Python a model saved with the Python API, and when I load in Scala a model saved with the Scala API, everything seems to be OK. But when I load in Scala a model saved with the Python API I get an error, and when I load in Python a model saved with the Scala API I get strange values. To the Scala and Python scripts that restore a model, I pass the path to the folder where I saved the model via argv. Using cat <filename> | protoc --decode_raw to analyse the contents of the two my.model.meta files generated with the two APIs, Python and Scala, the two files look different.

https://gist.github.com/lucaRadicalbit/e8addd67eb35c73148ab91c4e276a8c1

lucataglia commented 6 years ago

@eaplatanios With the code below I try to restore the model created with your LinearExample.scala (so this is the Scala-to-Scala case):

import java.io.{BufferedInputStream, FileInputStream}
import java.nio.file.Paths

import org.platanios.tensorflow.api._
import org.tensorflow.framework.MetaGraphDef

// stringPath, batch and logger are defined elsewhere in my program.
val path = Paths.get(stringPath + ".meta")
val checkpointPath = Paths.get(stringPath)
val mgf = MetaGraphDef.parseFrom(new BufferedInputStream(new FileInputStream(path.toFile)))

// Import the graph from the meta graph, then restore the variable values from the checkpoint.
val session = Session()
val saver = tf.Saver.fromMetaGraphDef(mgf)
saver.restore(session, checkpointPath)

// tf.Tensor in Python is the equivalent of Output here in Scala.
val weights = session.graph.getOutputByName("foo_weights:0")
val inputs = session.graph.getOutputByName("foo_inputs:0")
val outputs = session.graph.getOutputByName("foo_output:0")
val loss = session.graph.getOutputByName("foo_loss:0")
val trainOp = session.graph.getOpByName("foo_train_op_1/foo_train_op")

println(weights)
println(inputs)
println(outputs)
println(loss)
println(trainOp)

// Continue training the restored model for a few iterations.
for (i <- 0 to 50) {
  val trainBatch = batch(10000)
  val feeds = Map(inputs -> trainBatch._1, outputs -> trainBatch._2)
  val trainLoss = session.run(feeds = feeds, fetches = loss, targets = trainOp)

  logger.info(s"Train loss at iteration $i = ${trainLoss.scalar} ")
}

Why does the tf.variable weights have RESOURCE as its dataType?

Output(name = foo_weights:0, shape = [], dataType = RESOURCE)
Output(name = foo_inputs:0, shape = [?, 1], dataType = FLOAT32)
Output(name = foo_output:0, shape = [?, 1], dataType = FLOAT32)
Output(name = foo_loss:0, shape = [], dataType = FLOAT32)
foo_train_op_1/foo_train_op

When I try to run the session to print the previously trained weight value before starting the training, I get an error. I inserted this line just before the for loop:

println("weight initial value: " + session.run(fetches = weights).scalar)

Exception in thread "main" java.lang.UnsupportedOperationException: The resource data type is not supported on the Scala side.
    at org.platanios.tensorflow.api.types.RESOURCE$.getElementFromBuffer(DataType.scala:837)
    at org.platanios.tensorflow.api.types.RESOURCE$.getElementFromBuffer(DataType.scala:822)
    at org.platanios.tensorflow.api.tensors.Tensor.getElementAtFlattenedIndex(Tensor.scala:228)
    at org.platanios.tensorflow.api.tensors.Tensor.scalar(Tensor.scala:242)
    at linearRegression.LinearRegression_restoreModel4Training$.main(LinearRegression_restoreModel4Training.scala:59)
    at linearRegression.LinearRegression_restoreModel4Training.main(LinearRegression_restoreModel4Training.scala)

Is this something that will be supported in a future release?

eaplatanios commented 6 years ago

@lucaRadicalbit I'm really sorry for responding so late to this but I'm rushing for a deadline and I'm currently working on adding RNN support to the library.

Regarding your question, I think I now know what's the issue. It's not really a problem per se. TensorFlow supports two types of variables: (i) reference-typed which is what the Python API uses, and (ii) resource-backed which is what I am using. The TensorFlow team intends to eventually switch to resource-backed variables too and that's why I originally decided to go with that. They haven't done it yet, but it should happen at some point. That's probably what's causing the incompatibility.

Now, regarding the Scala to Scala case, try to run the following instead for the example you posted above:

println("weight initial value: " + session.run(fetches = weights.value).scalar)

The reason is that you are trying to obtain the scalar value of a pointer to a resource that holds the variable value. I will try to document this soon, but does it make sense for now?
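A toy model in plain Python may help picture the distinction between the pointer and the value. It does not use TensorFlow, and every name in it (Handle, VariableStore, create, read) is hypothetical: the point is only that fetching the variable itself yields a handle, while a separate read step yields the number.

```python
class Handle:
    """Opaque pointer to a stored variable; it carries no numeric value."""

    def __init__(self, name: str):
        self.name = name


class VariableStore:
    """Stands in for the runtime that owns the variable values."""

    def __init__(self):
        self._values = {}

    def create(self, name: str, value: float) -> Handle:
        self._values[name] = value
        return Handle(name)

    def read(self, handle: Handle) -> float:
        # The explicit read step: dereference the handle to get the value.
        return self._values[handle.name]


store = VariableStore()
weights = store.create("foo_weights", 0.5)

print(type(weights).__name__)  # the handle itself is not a number
print(store.read(weights))     # reading through the handle gives the value
```

In this analogy, asking for .scalar on the fetched RESOURCE output corresponds to treating the Handle as if it were the number.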

lucataglia commented 6 years ago

@eaplatanios Don't worry, I understand your rush. Regarding the Scala-to-Scala case, I tried the line of code you suggested, but I get a compile-time error. It seems that the value field is present on the Variable type but not on the Output and OutputLike types.

[Screenshot: 2017-11-13 10:22:18]

Regarding the other topic: looking at the TensorFlow documentation (https://www.tensorflow.org/api_docs/python/tf/get_variable), I found that with the Python API, when I create a variable, I can use the use_resource flag to specify whether to create a regular Variable or a ResourceVariable.

# Snippet of Python code
w = tf.get_variable("foo_weights", shape=[1, 1], initializer=tf.zeros_initializer, use_resource=True)

Correct me if I am wrong, but I suppose that's what you were talking about in your last response (this is all new to me). I set the flag to True, generated the model, and tried to restore it in the Scala environment, but I still get the same error:

2017-11-13 10:48:32.486421: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.2 AVX AVX2 FMA
Exception in thread "main" java.lang.IllegalArgumentException: Cannot find graph collection key named 'variables'.
    at org.platanios.tensorflow.api.core.Graph$Keys$$anonfun$fromName$1.apply(Graph.scala:1083)
    at org.platanios.tensorflow.api.core.Graph$Keys$$anonfun$fromName$1.apply(Graph.scala:1083)
    at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
    at scala.collection.AbstractMap.getOrElse(Map.scala:59)
    at org.platanios.tensorflow.api.core.Graph$Keys$.fromName(Graph.scala:1083)
    at org.platanios.tensorflow.api.core.Graph$$anonfun$importMetaGraphDef$2.apply(Graph.scala:617)
    at org.platanios.tensorflow.api.core.Graph$$anonfun$importMetaGraphDef$2.apply(Graph.scala:614)
    at scala.collection.Iterator$class.foreach(Iterator.scala:891)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
    at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
    at scala.collection.AbstractIterable.foreach(Iterable.scala:54)
    at org.platanios.tensorflow.api.core.Graph.importMetaGraphDef(Graph.scala:614)
    at org.platanios.tensorflow.api.ops.variables.Saver$.fromMetaGraphDef(Saver.scala:487)
    at python2scala.LinearRegression_restoreModel$.main(LinearRegression_restoreModel.scala:30)
    at python2scala.LinearRegression_restoreModel.main(LinearRegression_restoreModel.scala)

Are there any articles about TensorFlow internals that you can suggest I read in order to better understand how to handle these topics? So far I have only read A Tour of TensorFlow and the tensorflow.org documentation, and clearly I am missing some of the theory behind this.

lucataglia commented 6 years ago

@eaplatanios About the Exception in thread "main" java.lang.IllegalArgumentException: Cannot find graph collection key named 'variables'. problem: using the debugger, I tried to avoid the exception raised by the getOrElse call inside Graph.Keys.fromName(name). I manually changed the content of the val key to the value unbound_inputs, and that was enough to restore a model trained with the Python API. I even printed the weight value and it was exactly what I expected.

[Screenshot: 2017-11-13 15:54:27]

I don't know if this means something to you, but I noticed that three collection_defs are loaded from the metaGraph: variables, trainable_variables and train_op. The Keys.registry map, instead, contains only unbound_inputs.

[Screenshot: 2017-11-13 15:54:58]

To restore all the collection_defs, should Keys.registry also contain entries for variables, trainable_variables and train_op? And why did I succeed in restoring and training a previously saved model (saved with the Python API) while skipping the restore of the collection_defs entirely?

eaplatanios commented 6 years ago

@lucaRadicalbit I'm currently looking into this. I think it has to do with some static initialization code that should have been called before fromName is called, but I'm looking into it.
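To illustrate the kind of initialization-order problem suspected here, the sketch below models a key registry in plain Python. It is a toy under stated assumptions, not the library's code: the names registry, Key, from_name and register_default_keys are hypothetical, and the only detail taken from this thread is that the registry held just unbound_inputs when the lookup for 'variables' failed.

```python
# Hypothetical registry: keys become known only once they have been constructed.
registry = {}


class Key:
    def __init__(self, name: str):
        self.name = name
        registry[name] = self  # registration is a side effect of construction


def from_name(name: str) -> Key:
    """Look up a key by name, failing the way the error in this issue does."""
    try:
        return registry[name]
    except KeyError:
        raise ValueError(
            f"Cannot find graph collection key named '{name}'."
        ) from None


def register_default_keys() -> None:
    """Initialization step that has to run before any lookup by name."""
    for name in ("variables", "trainable_variables", "train_op"):
        Key(name)


# Only this key has been constructed so far, mirroring the debugger screenshot.
Key("unbound_inputs")
```

Calling from_name("variables") at this point raises the error; calling it after register_default_keys() returns the key. The general remedy for this class of bug is to make sure the registration code has run before the first lookup.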

Regarding the resource-based variables, what you say is correct and the use_resource flag should work. On second thought though, that shouldn't really affect loading the model. In either case, I'm looking into this now and will update you once I have something. :)

eaplatanios commented 6 years ago

@lucaRadicalbit This should be fixed now. It should now also give you an informative error if you haven't used use_resource=True in Python and you try to load a model in Scala. Otherwise, it should load it just fine. One last thing to check would be to save a model in Scala and load it in Python, but I didn't get the chance to do that.

lucataglia commented 6 years ago

@eaplatanios Now it works!!! I also tried saving a model in Scala and loading it in Python, and that works too. Just one clarification: when I use Python to generate the model . . .

# Python snippet
w = tf.get_variable("foo_weights", dtype=tf.float32, shape=[1, 1], initializer=tf.zeros_initializer, use_resource=True)

. . . and I restore that model with the Scala API, I have to use "foo_weights/Read/ReadVariableOp:0" if I want to print the weight value previously trained in Python

// Scala snippet
val weights = session.graph.getOutputByName("foo_weights/Read/ReadVariableOp:0")
session.run(fetches = weights).scalar

When I use Scala to generate the model . . .

// Scala snippet
val weights = tf.variable("foo_weights", FLOAT32, Shape(1, 1), tf.zerosInitializer)

. . . and I use Python to restore the model trained with the Scala API, I have to use "foo_weights/ReadVariable:0" if I want to print the weight value previously trained in Scala

# Python snippet
w = graph.get_tensor_by_name("foo_weights/ReadVariable:0")

I don't know whether this is helpful to anyone, but I wrote it down because the two strings needed to load the tensor from a meta graph differ, and someone else may need this information. For the moment I have only tried with your linear regression example; in the future I think I'll try with different models too. If you think an example of storing and reloading a model would be a useful addition to the examples in your repository, let me know and I'll make a pull request.
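For anyone who needs both directions, the two names reported above can be wrapped in a small helper. This is a sketch in plain Python: the function name and the saved_with argument are made up, and the suffixes are only the ones observed in this thread, so they may differ for other library versions or other kinds of variables.

```python
# Suffix of the tensor that reads a resource variable's value, keyed by the
# API that saved the model. The values are the ones observed in this thread.
READ_SUFFIX = {
    "python": "/Read/ReadVariableOp:0",
    "scala": "/ReadVariable:0",
}


def read_tensor_name(variable_name: str, saved_with: str) -> str:
    """Return the name of the tensor holding the variable's value."""
    try:
        return variable_name + READ_SUFFIX[saved_with]
    except KeyError:
        raise ValueError(f"unknown API: {saved_with!r}") from None


print(read_tensor_name("foo_weights", "python"))  # foo_weights/Read/ReadVariableOp:0
print(read_tensor_name("foo_weights", "scala"))   # foo_weights/ReadVariable:0
```

The returned string is what would be passed to getOutputByName in Scala or get_tensor_by_name in Python, depending on which side is doing the loading.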

eaplatanios commented 6 years ago

@lucaRadicalbit I'm glad it works now! :)

Regarding the names, yes that is expected. I have implemented a somewhat cleaner naming scheme than the Python API removing some of the clutter. It's also quite hard to replicate the naming scheme exactly as you would need to do that for every single op definition (lots of the Python implementations for ops add custom name scopes -- such as "Read" in your example above). Do you think that's ok?

Regarding the example, yes that would be awesome! :) Thanks! Also, given you've already done most of that, could you submit an example of Python API / Scala API interop through saving/loading models? Feel free to just add the Python code alongside the Scala code for now and I can clean it up and restructure it after you push it. :)