Grzego / handwriting-generation

Implementation of handwriting generation using recurrent neural networks in TensorFlow. Based on Alex Graves' paper (https://arxiv.org/abs/1308.0850).
MIT License

Cannot Restore Pretrained Model #6

Open dsleo opened 6 years ago

dsleo commented 6 years ago

Hello,

First off, I would like to thank you for this amazing work. I was giving it a go, hoping to use your pretrained model directly. However, I ran into the following error when running generate.py --text="this was generated by computer" --bias=1. Any help would be greatly appreciated. I think some variables in the checkpoint file might have been renamed?

NotFoundError: Key rnnlm/multi_rnn_cell/cell_10/basic_lstm_cell/bias/Adam_1 not found in checkpoint
     [[Node: save/RestoreV2_17 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_17/tensor_names, save/RestoreV2_17/shape_and_slices)]]

During handling of the above exception, another exception occurred:

NotFoundError                             Traceback (most recent call last)
/data.nfs/leo/handwriting-generation/generate.py in <module>()
    244 
    245 if __name__ == '__main__':
--> 246     main()

/data.nfs/leo/handwriting-generation/generate.py in main()
    132     with tf.Session(config=config) as sess:
    133         saver = tf.train.import_meta_graph(args.model_path + '.meta')
--> 134         saver.restore(sess, args.model_path)
    135 
    136         while True:

/data/dss-data-dir/code-envs/python/test_py3/lib/python3.6/site-packages/tensorflow/python/training/saver.py in restore(self, sess, save_path)
   1558     """
   1559     meta_graph.add_collection_def(meta_graph_def, key,
-> 1560                                   export_scope=export_scope)
   1561 

Also, I am running TensorFlow 1.2.0.
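
The renaming hypothesis above can be checked by comparing the variable names the restored graph asks for against the names actually stored in the checkpoint. Below is a minimal sketch in plain Python. It assumes both name lists have already been collected (in TensorFlow 1.x the checkpoint side can be read with `tf.train.NewCheckpointReader(path).get_variable_to_shape_map()`). The helper `diff_checkpoint_keys` and the sample names are hypothetical illustrations, not part of the repository or taken from the real checkpoint.

```python
def diff_checkpoint_keys(graph_keys, checkpoint_keys):
    """Compare variable names requested by the graph with those in a checkpoint.

    Returns (missing, unused):
      missing - names the graph wants but the checkpoint does not contain
      unused  - names stored in the checkpoint that the graph never requests
    """
    graph_set = set(graph_keys)
    ckpt_set = set(checkpoint_keys)
    missing = sorted(graph_set - ckpt_set)
    unused = sorted(ckpt_set - graph_set)
    return missing, unused


# Illustrative names only (assumption): a graph whose cell index differs
# from the index that was used when the checkpoint was written.
graph_keys = [
    "rnnlm/multi_rnn_cell/cell_10/basic_lstm_cell/bias",
    "rnnlm/multi_rnn_cell/cell_10/basic_lstm_cell/kernel",
]
checkpoint_keys = [
    "rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/bias",
    "rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/kernel",
]

missing, unused = diff_checkpoint_keys(graph_keys, checkpoint_keys)
print("missing from checkpoint:", missing)
print("not requested by graph:", unused)
```

If `missing` and `unused` differ only in a cell index or a scope prefix, that would point to a naming mismatch between the graph and the checkpoint rather than a corrupted file.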

Grzego commented 6 years ago

Hi,

What is your OS? I used this code mostly on Windows 10, but I also tested it on Linux (Manjaro) and it works there too. To be honest, though, I ran it with TensorFlow 1.5 (as I upgraded recently).

I'm not sure what the problem might be here. Are you able to test with TensorFlow 1.5?

dsleo commented 6 years ago

Hi,

Thanks for your reply. Sorry, I should have stated my OS clearly: it's CentOS 7. I've upgraded to TensorFlow 1.5 and now get the following similar error:

NotFoundError: Key rnnlm/multi_rnn_cell/cell_103/basic_lstm_cell/kernel not found in checkpoint
     [[Node: save/RestoreV2_42 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_42/tensor_names, save/RestoreV2_42/shape_and_slices)]]

During handling of the above exception, another exception occurred:

NotFoundError                             Traceback (most recent call last)
/data.nfs/leo/handwriting-generation/generate.py in <module>()
    244 
    245 if __name__ == '__main__':
--> 246     main()

/data.nfs/leo/handwriting-generation/generate.py in main()
    132     with tf.Session(config=config) as sess:
    133         saver = tf.train.import_meta_graph(args.model_path + '.meta')
--> 134         saver.restore(sess, args.model_path)
    135 
    136         while True:

/data/dss-data-dir/code-envs/python/test_py3/lib/python3.6/site-packages/tensorflow/python/training/saver.py in restore(self, sess, save_path)
   1558       logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
   1559       logging.warning("Consider switching to the more efficient V2 format:")
-> 1560       logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
   1561       logging.warning("now on by default.")
   1562       logging.warning("*******************************************************")

/data/dss-data-dir/code-envs/python/test_py3/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    893     try:
    894       result = self._run(None, fetches, feed_dict, options_ptr,
--> 895                          run_metadata_ptr)
    896       if run_metadata:
    897         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/data/dss-data-dir/code-envs/python/test_py3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1122     final_fetches = fetch_handler.fetches()
   1123     final_targets = fetch_handler.targets()
-> 1124     # We only want to really perform the run if fetches or targets are provided,
   1125     # or if the call is a partial run that specifies feeds.
   1126     if final_fetches or final_targets or (handle and feed_dict_tensor):

/data/dss-data-dir/code-envs/python/test_py3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1319       # Ensure any changes to the graph are reflected in the runtime.
   1320       self._extend_graph()
-> 1321       with errors.raise_exception_on_not_ok_status() as status:
   1322         if self._created_with_new_api:
   1323           return tf_session.TF_SessionRun_wrapper(

/data/dss-data-dir/code-envs/python/test_py3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1338         else:
   1339           return tf_session.TF_PRun(session, handle, feed_dict, fetch_list,
-> 1340                                     status)
   1341 
   1342     if handle is None:

Let me know if the log provided is not enough!

Grzego commented 6 years ago

This is strange. I just ran this command in Docker with CentOS 7 and it works fine. Are you sure all files are intact?
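
One way to check whether the model files are intact is to compute a checksum of each file on the failing machine and compare it with the same checksum computed on a machine where generation works. Below is a minimal sketch using only the Python standard library. The function name `sha256_of_file` is hypothetical, and which files to hash (for example the `.meta`, `.index`, and `.data-*` files of the checkpoint) is an assumption about the repository layout.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of the file at `path`.

    The file is read in chunks so that large checkpoint files
    do not have to fit in memory at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        # iter() with a sentinel keeps reading until read() returns b"" (EOF).
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

Running the function on each checkpoint file on both machines and comparing the resulting strings shows whether the files differ; identical digests would rule out a damaged download or clone.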

dsleo commented 6 years ago

I've just cloned the repo directly and haven't touched the model files. I will try again from scratch later this week and let you know!

af258963 commented 6 years ago

@Grzego Thanks for sharing, looking forward to a Chinese version!