google / prettytensor

Pretty Tensor: Fluent Networks in TensorFlow

Reload the weights used in the model #21

Closed: VigneshSrinivasan10 closed this issue 8 years ago

VigneshSrinivasan10 commented 8 years ago

Hi,

This question is a follow-up to issue #6.

How do we load the weights back into the model once they have been saved? I saved them like this:

  import pickle
  import tensorflow as tf
  vars = sess.run(tf.get_collection(tf.GraphKeys.VARIABLES))
  pickle.dump(vars, open('vars.npy', 'wb'))

Thanks in advance!

eiderman commented 8 years ago

You need to use the saver built into TensorFlow (tf.train.Saver) or the checkpointing built into pt.train.Runner, which uses the saver under the hood (set save_path when you create the Runner).

Because of the deferred nature of TF, the variables in the collection are just pointers into the graph def and contain no data.
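A minimal sketch of the tf.train.Saver route, assuming TensorFlow 2.x running the v1 graph/session API through `tf.compat.v1` (on the TF 1.x releases of this thread, drop the `compat.v1` prefix). The two-variable model, `build_graph`, and the checkpoint path are hypothetical. The point it shows: rebuild the same graph, then call `saver.restore` instead of running the variable initializer.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TF 2.x using the v1 graph/session API via tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def build_graph():
    """Build a tiny hypothetical model (one weight matrix, one bias) plus a Saver."""
    graph = tf1.Graph()
    with graph.as_default():
        w = tf1.get_variable("w", shape=[2, 2],
                             initializer=tf1.random_normal_initializer())
        b = tf1.get_variable("b", shape=[2],
                             initializer=tf1.zeros_initializer())
        # With no arguments, the Saver covers all global variables in this graph.
        saver = tf1.train.Saver()
    return graph, w, b, saver


# Hypothetical checkpoint location in a temporary directory.
ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "model.ckpt")

# 1) Initialize (stand-in for training) and write a checkpoint.
graph, w, b, saver = build_graph()
with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    saved_w = sess.run(w)
    saver.save(sess, ckpt_path)

# 2) Rebuild the same graph in a fresh session and restore the saved values.
graph2, w2, b2, saver2 = build_graph()
with tf1.Session(graph=graph2) as sess:
    # No initializer here: restore() assigns the values stored in the checkpoint.
    saver2.restore(sess, ckpt_path)
    restored_w = sess.run(w2)

print(np.allclose(saved_w, restored_w))
```

By default the Saver matches variables by name, so the rebuilt graph has to create them under the same names it used when the checkpoint was written.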