eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

Error while restoring model after distributed synchronous training #65

Closed: lucataglia closed this issue 6 years ago

lucataglia commented 6 years ago

@eaplatanios I wrote a Python script that does distributed synchronous training using the between-graph replication approach, with 3 workers and 3 parameter servers (not sure whether that matters). I ran the script, did my training, and then saved the model using tf.train.CheckpointSaverHook. Below is part of the protobuf generated by the save operation (a simplified sketch of the full setup follows it):

collection_def {
  key: "variables"
  value {
    bytes_list {
      value: "\n\rfoo_weights:0\022\022foo_weights/Assign(\001"
      value: "\n\rglobal_step:0\022\022global_step/Assign\032\022global_step/read:0"
      value: "\n\025foo_weights/Adagrad:0\022\032foo_weights/Adagrad/Assign(\001"
    }
  }
}
collection_def {
  key: "trainable_variables"
  value {
    bytes_list {
      value: "\n\rfoo_weights:0\022\022foo_weights/Assign(\001"
    }
  }
}
collection_def {
  key: "global_step"
  value {
    node_list {
      value: "global_step:0"
    }
  }
}
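
For reference, the whole setup looks roughly like this. This is a simplified sketch, not my exact script: the cluster addresses, task index, checkpoint directory, and the foo_weights model are placeholders.

```python
import tensorflow as tf

# Simplified sketch of the between-graph replication setup; cluster
# addresses, task index, and the model are placeholders.
cluster = tf.train.ClusterSpec({
    "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
    "ps": ["ps0:2222", "ps1:2222", "ps2:2222"],
})
server = tf.train.Server(cluster, job_name="worker", task_index=0)

with tf.device(tf.train.replica_device_setter(cluster=cluster)):
    # use_resource=True so that the Scala API can handle the variable.
    tf.get_variable(tf.GraphKeys.GLOBAL_STEP, shape=[], dtype=tf.int64,
                    initializer=tf.zeros_initializer(),
                    collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                                 tf.GraphKeys.GLOBAL_STEP],
                    use_resource=True)
    global_step = tf.train.get_or_create_global_step()

    foo_weights = tf.get_variable("foo_weights", shape=[10],
                                  initializer=tf.zeros_initializer(),
                                  use_resource=True)
    loss = tf.reduce_sum(tf.square(foo_weights - 1.0))

    optimizer = tf.train.SyncReplicasOptimizer(
        tf.train.AdagradOptimizer(0.01),
        replicas_to_aggregate=3, total_num_replicas=3)
    train_op = optimizer.minimize(loss, global_step=global_step)

hooks = [optimizer.make_session_run_hook(is_chief=True),
         tf.train.StopAtStepHook(last_step=1000),
         tf.train.CheckpointSaverHook(checkpoint_dir="/tmp/model",
                                      save_steps=100)]
with tf.train.MonitoredTrainingSession(master=server.target, is_chief=True,
                                       hooks=hooks) as sess:
    while not sess.should_stop():
        sess.run(train_op)
```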

On the Python side I have to create a global_step variable in order to get synchronous distributed training. These are the two lines of code I use to create that tf.Variable, making sure it is a resource-based variable:

# I have to specify use_resource=True to get a resource-based variable.
tf.get_variable(tf.GraphKeys.GLOBAL_STEP, shape=[], dtype=tf.int64,
                initializer=tf.zeros_initializer(),
                collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                             tf.GraphKeys.GLOBAL_STEP],
                use_resource=True)
global_step = tf.train.get_or_create_global_step()

If I just use the tf.train.get_or_create_global_step() method I get a reference-based variable. With these two lines of code I succeeded in creating a resource-based variable (the kind of tf.Variable the Scala API can handle), but during the restore operation, when the Scala API tries to parse the global_step collection_def, I get an error: The global_step collection should be stored as a byte list.

I understand why I get this error: in the protobuf I can see that the global_step collection is stored in a node_list. The GLOBAL_STEP object extends VariableCollectionKey

[screenshot: the GLOBAL_STEP object definition, extending VariableCollectionKey]

and inside the parseCollectionDef method there is a check that kind != CollectionDef.KindCase.BYTES_LIST.

[screenshot: parseCollectionDef rejecting any collection_def whose kind is not BYTES_LIST]

Is VariableCollectionKey the right trait for the GLOBAL_STEP object to extend? Because with this trait, the concrete implementation of the parseCollectionDef method gives me the error above.

eaplatanios commented 6 years ago

@lucaRadicalbit Could you try replacing this:

global_step = tf.train.get_or_create_global_step()

with this:

global_step = tf.get_default_graph().get_collection(tf.GraphKeys.GLOBAL_STEP)[0]

It seems that sometimes the Python API obtains the value of the global step as a tensor and sometimes it obtains the variable itself. In my API, I always obtain the variable directly, because I think that makes more sense and agrees with the semantics of what the global step is.
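
For example, you could print what actually lives in that collection after building the graph; a tf.Variable there ends up serialized as a bytes_list in the MetaGraphDef, while a tf.Tensor ends up as a node_list (a quick sketch):

```python
import tensorflow as tf

# Print what the "global_step" collection actually contains: a
# tf.Variable serializes into a bytes_list in the MetaGraphDef,
# whereas a tf.Tensor serializes into a node_list.
for item in tf.get_default_graph().get_collection(tf.GraphKeys.GLOBAL_STEP):
    print(type(item), item.name)
```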

lucataglia commented 6 years ago

@eaplatanios I tried, but I get the same behaviour and the same protobuf structure:

collection_def {
  key: "variables"
  value {
    bytes_list {
      value: "\n\rfoo_weights:0\022\022foo_weights/Assign(\001"
      value: "\n\rglobal_step:0\022\022global_step/Assign(\001"
      value: "\n\025foo_weights/Adagrad:0\022\032foo_weights/Adagrad/Assign(\001"
    }
  }
}
collection_def {
  key: "global_step"
  value {
    node_list {
      value: "global_step:0"
    }
  }
}

eaplatanios commented 6 years ago

And that is while you're still using tf.get_variable(tf.GraphKeys.GLOBAL_STEP, shape=[], dtype=tf.int64, initializer=tf.zeros_initializer(), collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP], use_resource=True)? Can you check that the global step variable you create is indeed a resource variable after you create it?
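
One quick way to check on the Python side (a sketch; note that it relies on an internal TensorFlow module):

```python
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

# A resource-based variable is an instance of ResourceVariable;
# a reference-based one is a plain tf.Variable.
global_step = tf.train.get_or_create_global_step()
print(isinstance(global_step, resource_variable_ops.ResourceVariable))
```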

lucataglia commented 6 years ago

I think the tf.Variable global_step is correctly created as resource-based, because when I debug the Scala execution I see the is_resource flag set to True.

[screenshot: debugger stopped at the fromProto call, showing is_resource = true]

The image shows the call to the fromProto method made by the foreach executed inside the parseCollectionDef method implemented by the VariableCollectionKey trait.

eaplatanios commented 6 years ago

Wait, if you're getting to the fromProto method in parseCollectionDef then you shouldn't be getting the previous error. It should happen in the if-statement right before that line in parseCollectionDef.

eaplatanios commented 6 years ago

I'm just a bit confused as to when and where the error happens.

lucataglia commented 6 years ago

No, because the problem isn't with the global_step variable but with the global_step collection_def. The Python code saves something inside the "variables" collection_def and something else inside the "global_step" collection_def. I'd like to be more specific, but I'm no expert on how TensorFlow serializes the model. Going back to Scala, this collection is parsed correctly:

collection_def {
  key: "variables"
  value {
    bytes_list {
      value: "\n\rfoo_weights:0\022\022foo_weights/Assign(\001"
      value: "\n\rglobal_step:0\022\022global_step/Assign(\001"
      value: "\n\025foo_weights/Adagrad:0\022\032foo_weights/Adagrad/Assign(\001"
    }
  }
}

Parsing this one gives me the error:

collection_def {
  key: "global_step"
  value {
    node_list {
      value: "global_step:0"
    }
  }
}

When the API parses the global_step collection_def, it looks for a collection_def that has a bytes_list, but in my protobuf there is a node_list. I saw that when `private[api] val registry = mutable.Map.empty[String, Key[_]]` is populated, a `GLOBAL_STEP` object that extends VariableCollectionKey is added, but the concrete implementation of parseCollectionDef in that trait looks for a collection_def with a bytes_list, while Python saves the `global_step` collection_def as a node_list. I can't figure out why the two differ.
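
You can see the mismatch directly by dumping how each collection is stored in the exported MetaGraphDef (a sketch; the .meta path is a placeholder):

```python
from tensorflow.core.protobuf import meta_graph_pb2

# Read the exported MetaGraphDef and print how each collection is
# stored; the path below is a placeholder for the real .meta file.
meta_graph_def = meta_graph_pb2.MetaGraphDef()
with open("/tmp/model/model.ckpt-0.meta", "rb") as f:
    meta_graph_def.ParseFromString(f.read())
for key, collection in meta_graph_def.collection_def.items():
    # Prints e.g. "variables bytes_list" and "global_step node_list".
    print(key, collection.WhichOneof("kind"))
```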

lucataglia commented 6 years ago

If I wasn't clear in my last comment, I promise I'll write a more complete explanation as soon as I can.

eaplatanios commented 6 years ago

Oh, I get it now; no more detailed explanation needed. It looks like the Python API saves the value of the variable in the "global_step" collection and the variable itself in the "variables" collection. That's weird and a bit inconsistent, to be honest. I'll look into finding a way to save it as a variable from the Python side, because I feel that makes more sense.
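
In the meantime, something like the following might work as a workaround on the Python side (an untested sketch):

```python
import tensorflow as tf

# Untested sketch: if the "global_step" collection holds a tf.Tensor,
# swap in the corresponding tf.Variable so that the collection is
# exported as a bytes_list instead of a node_list.
graph = tf.get_default_graph()
items = graph.get_collection(tf.GraphKeys.GLOBAL_STEP)
if items and isinstance(items[0], tf.Tensor):
    step_var = next(v for v in tf.global_variables()
                    if v.name == items[0].name)
    graph.clear_collection(tf.GraphKeys.GLOBAL_STEP)
    graph.add_to_collection(tf.GraphKeys.GLOBAL_STEP, step_var)
```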

lucataglia commented 6 years ago

Here is the link to the gist with the Python code I use. I hope all my comments are consistent, but they should be: https://gist.github.com/lucaRadicalbit/5bf312e288e8fea94a36ffab1ed09a2b

Some lines in the code also carry comments describing the errors I get.

lucataglia commented 6 years ago

I also asked on Stack Overflow for details about the difference between bytes_list and node_list. I'm adding the link in case it's useful to someone once I get an answer: https://stackoverflow.com/questions/47816188/tensor-flow-metagraphdef-protobuf-difference-between-node-list-and-bytes-list

eaplatanios commented 6 years ago

@lucaRadicalbit Is this still an issue?

lucataglia commented 6 years ago

@eaplatanios I don't know, because I dropped that approach and followed a new one. I don't know when, but when I have time I'll give it a try and let you know.

eaplatanios commented 6 years ago

@lucaRadicalbit Ok, thanks! In that case, I'll close this issue for now, as it's also mostly relevant to the TF Python API and not this project. Feel free to reopen if you keep having issues related to this. :)