Closed lucataglia closed 6 years ago
@lucaRadicalbit Could you try replacing this:
global_step = tf.train.get_or_create_global_step()
with this:
global_step = tf.get_default_graph().get_collection(ops.GraphKeys.GLOBAL_STEP)[0]
It seems that sometimes the Python API obtains the value of the global step as a tensor and sometimes it obtains the variable itself. In my API, I always obtain the variable directly, because I think that makes more sense and agrees with the semantics of what the global step is.
@eaplatanios I tried but I got the same behaviour and the same protobuf structure
collection_def {
  key: "variables"
  value {
    bytes_list {
      value: "\n\rfoo_weights:0\022\022foo_weights/Assign(\001"
      value: "\n\rglobal_step:0\022\022global_step/Assign(\001"
      value: "\n\025foo_weights/Adagrad:0\022\032foo_weights/Adagrad/Assign(\001"
    }
  }
}
collection_def {
  key: "global_step"
  value {
    node_list {
      value: "global_step:0"
    }
  }
}
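(Side note for readers: the shape of the dump above can be summarized without TensorFlow. A minimal stdlib-only sketch, where plain dicts stand in for the `CollectionDef` protos and the names are taken from the dump above, reports which kind each collection uses:)

```python
# Plain-dict stand-ins for the two collection_def entries in the dump above.
# In a real MetaGraphDef these are CollectionDef protos; dicts keep the
# sketch runnable without TensorFlow installed.
collection_def = {
    "variables": {
        "kind": "bytes_list",  # entries are serialized variable protos
        "values": [
            b"\n\rfoo_weights:0\x12\x12foo_weights/Assign(\x01",
            b"\n\rglobal_step:0\x12\x12global_step/Assign(\x01",
        ],
    },
    "global_step": {
        "kind": "node_list",   # entries are plain tensor names
        "values": ["global_step:0"],
    },
}

def collection_kinds(cdefs):
    """Map each collection key to the kind its values are stored as."""
    return {key: entry["kind"] for key, entry in cdefs.items()}

print(collection_kinds(collection_def))
# {'variables': 'bytes_list', 'global_step': 'node_list'}
```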
And that is while you're still using `tf.get_variable(tf.GraphKeys.GLOBAL_STEP, shape=[], dtype=tf.int64, initializer=tf.zeros_initializer(), collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP], use_resource=True)`? Can you check that the global step variable you create is indeed a resource variable after you create it?
I think the tf.Variable global_step is correctly created as resource-based, because when I debug the Scala execution I see the is_resource flag is True.
The image shows the call to the `fromProto` method done by the foreach executed by the `parseCollectionDef` method implemented by the `VariableCollectionKey` trait.
Wait, if you're getting to the `fromProto` method in `parseCollectionDef`, then you shouldn't be getting the previous error. It should happen in the if-statement right before that line in `parseCollectionDef`. I'm just a bit confused as to when and where the error happens.
No, because the problem isn't with the `global_step` variable, but with the `global_step` `collection_def`. The Python code saves something inside the `"variables"` `collection_def` and something inside the `"global_step"` `collection_def`. I'd like to be more specific, but I'm not a big expert on how TensorFlow serializes the model. Going back to Scala, the parsing of this collection is done correctly:
collection_def {
  key: "variables"
  value {
    bytes_list {
      value: "\n\rfoo_weights:0\022\022foo_weights/Assign(\001"
      value: "\n\rglobal_step:0\022\022global_step/Assign(\001"
      value: "\n\025foo_weights/Adagrad:0\022\032foo_weights/Adagrad/Assign(\001"
    }
  }
}
The parsing of this one gives me the error:
collection_def {
  key: "global_step"
  value {
    node_list {
      value: "global_step:0"
    }
  }
}
When the API parses the `global_step` `collection_def`, it looks for a `collection_def` that has a `bytes_list`, but in my protobuf there is a `node_list`. I saw that when the `private[api] val registry = mutable.Map.empty[String, Key[_]]` is populated, a `GLOBAL_STEP` object that extends `VariableCollectionKey` is added, but the concrete implementation of `parseCollectionDef` in that trait looks for a `collection_def` that has a `bytes_list`, while Python saves the `global_step` `collection_def` as a `node_list`. I can't get why the two things are different.
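(For readers following along, here is a stdlib-only sketch of that mismatch. The check below is modeled on the behavior described for `VariableCollectionKey.parseCollectionDef`, not copied from the actual Scala source, and the kind names are simplified strings:)

```python
BYTES_LIST = "bytes_list"
NODE_LIST = "node_list"

def parse_variable_collection(key, kind, values):
    # Models the check described above: a variable collection key is
    # expected to have been serialized as a bytes_list.
    if kind != BYTES_LIST:
        raise ValueError(f"The {key} collection should be stored as a byte list.")
    return values

# "variables" (a bytes_list) parses fine...
parse_variable_collection("variables", BYTES_LIST,
                          [b"\n\rglobal_step:0\x12\x12global_step/Assign(\x01"])

# ...while "global_step" (a node_list in the protobuf above) reproduces
# the reported error.
try:
    parse_variable_collection("global_step", NODE_LIST, ["global_step:0"])
except ValueError as err:
    print(err)  # The global_step collection should be stored as a byte list.
```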
If I wasn't clear with my last comment, I promise I'll write a more complete explanation as soon as I can.
Oh I get it now, no more detailed explanation needed. It looks like the Python API saves the value of the variable in the "global_step" collection and the variable itself in the "variables" collection. That's weird and a bit inconsistent to be honest. I'll look into finding a way to save it as a variable from the Python side, because I feel it makes more sense.
Here is the link to the gist with the Python code I use. I hope all my comments are consistent with it, but they should be: https://gist.github.com/lucaRadicalbit/5bf312e288e8fea94a36ffab1ed09a2b
Some lines in the code also carry a comment with the error that I get.
I also asked on StackOverflow for some details about the difference between `bytes_list` and `node_list`. I'll leave the link here; maybe it can be useful to someone once I get an answer: https://stackoverflow.com/questions/47816188/tensor-flow-metagraphdef-protobuf-difference-between-node-list-and-bytes-list
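(While waiting for an answer there, the difference is visible in the payloads themselves: a `node_list` carries plain tensor names as strings, while a `bytes_list` carries serialized protos. As a hedged illustration, a tiny stdlib parser can pull the fields back out of one of the `bytes_list` entries shown earlier. The field numbers 1 = `variable_name` and 5 = `is_resource` are my assumption about the serialized variable message, and the decoder only handles single-byte tags and lengths:)

```python
def parse_fields(buf):
    """Decode a flat protobuf message into {field_number: value}.

    Handles only the two wire types present in the entry below:
    0 (varint) and 2 (length-delimited), with single-byte tags/lengths.
    """
    fields, i = {}, 0
    while i < len(buf):
        tag = buf[i]; i += 1
        field_no, wire_type = tag >> 3, tag & 0x07
        if wire_type == 2:            # length-delimited (strings, sub-messages)
            length = buf[i]; i += 1
            fields[field_no] = buf[i:i + length]
            i += length
        elif wire_type == 0:          # varint
            value, shift = 0, 0
            while True:
                b = buf[i]; i += 1
                value |= (b & 0x7F) << shift
                if not b & 0x80:
                    break
                shift += 7
            fields[field_no] = value
        else:
            raise ValueError(f"unhandled wire type {wire_type}")
    return fields

# One of the bytes_list entries from the protobuf dump above.
entry = b"\n\rglobal_step:0\x12\x12global_step/Assign(\x01"
fields = parse_fields(entry)
print(fields[1].decode())  # global_step:0  (assumed: variable_name, field 1)
print(fields[5])           # 1              (assumed: is_resource, field 5)
```

Notice that field 5 decodes to 1 here, which lines up with the variable having been created with `use_resource=True`.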
@lucaRadicalbit Is this still an issue?
@eaplatanios I don't know, because I dropped that approach and followed a new one. I don't know when, but when I have time I'll give it a try and let you know.
@lucaRadicalbit Ok, thanks! In that case, I'll close this issue for now, as it's also mostly relevant to the TF Python API and not this project. Feel free to reopen if you keep having issues related to this. :)
@eaplatanios I wrote a Python script that does distributed synchronous training using the between-graph replication approach. I have 3 workers and 3 PSs (I don't know, maybe this info is useful). I ran that script, did my training, and then saved the model using tf.train.CheckpointSaverHook. Below is part of the protobuf generated by the saving operation:
On the Python side I must create a global_step variable in order to get distributed-training synchronicity. These are the two lines of code I use to create that tf.Variable, making sure it is a resource-based variable:
If I just use the tf.train.get_or_create_global_step() method, I get a reference-based variable. With these two lines of code I succeeded in creating a resource-based variable (which is the kind of tf.Variable that the Scala API can handle), but during the restore operation, when the Scala API tries to parse the global_step collection_def, I get an error: `The global_step collection should be stored as a byte list.`
I understand why I get this error, because in the protobuf I can see that the global_step collection is stored in a node_list. The GLOBAL_STEP object extends VariableCollectionKey, and inside the parseCollectionDef method it is checked that `kind != CollectionDef.KindCase.BYTES_LIST`. Is the trait that the GLOBAL_STEP object extends the right one? Because with this trait, the concrete implementation of the parseCollectionDef method gives me the error above.