UKPLab / fever-2018-team-athene

Apache License 2.0

Cannot load pretrained sentence retrieval model #36

Open · Martin36 opened this issue 2 years ago

Martin36 commented 2 years ago

When trying to load the pretrained ESIM model for sentence retrieval, I get the following error:

Exception has occurred: NotFoundError
Key encode_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/bias not found in checkpoint
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Caused by op 'save/RestoreV2', defined at:
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/martin/.vscode/extensions/ms-python.python-2021.12.1559732655/pythonFiles/lib/python/debugpy/__main__.py", line 45, in <module>
    cli.main()
  File "/home/martin/.vscode/extensions/ms-python.python-2021.12.1559732655/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main
    run()
  File "/home/martin/.vscode/extensions/ms-python.python-2021.12.1559732655/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file
    runpy.run_path(target_as_str, run_name=compat.force_str("__main__"))
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/sentence_retrieval.py", line 246, in <module>
    main(model="esim")
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/sentence_retrieval.py", line 193, in main
    clf.restore_model(os.path.join(model_store_dir, "best_model.ckpt"))
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/deep_models/ESIM.py", line 438, in restore_model
    self._construct_graph()
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/deep_models/ESIM.py", line 211, in _construct_graph
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1338, in __init__
    self.build()
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1347, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1384, in _build
    build_save=build_save, build_restore=build_restore)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 835, in _build_internal
    restore_sequentially, reshape)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 472, in _AddRestoreOps
    restore_sequentially)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 886, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1463, in restore_v2
    shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

NotFoundError (see above for traceback): Key encode_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/bias not found in checkpoint
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

During handling of the above exception, another exception occurred:

  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/deep_models/ESIM.py", line 447, in restore_model
    self._saver.restore(self._session, path)
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/sentence_retrieval.py", line 193, in main
    clf.restore_model(os.path.join(model_store_dir, "best_model.ckpt"))
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/sentence_retrieval.py", line 246, in <module>
    main(model="esim")

My guess is that this is caused by a mismatch between the variables in the tf.GraphKeys.TRAINABLE_VARIABLES collection (the list the Saver in ESIM._construct_graph is built from) and the variables stored in the .ckpt file.

These are the variables in the trainable-variables collection:

Trainable variables:
  <tf.Variable 'encode_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/kernel:0' shape=(428, 512) dtype=float32_ref>
  <tf.Variable 'encode_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/bias:0' shape=(512,) dtype=float32_ref>
  <tf.Variable 'infer_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/kernel:0' shape=(1152, 512) dtype=float32_ref>
  <tf.Variable 'infer_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/bias:0' shape=(512,) dtype=float32_ref>
  <tf.Variable 'dense/kernel:0' shape=(1024, 256) dtype=float32_ref>
  <tf.Variable 'dense/bias:0' shape=(256,) dtype=float32_ref>
  <tf.Variable 'dense_1/kernel:0' shape=(256, 1) dtype=float32_ref>
  <tf.Variable 'dense_1/bias:0' shape=(1,) dtype=float32_ref>

And these are the variables stored in the .ckpt file:

Variables found in checkpoint file:
  ('dense/bias', [256])
  ('dense/kernel', [1024, 256])
  ('dense_1/bias', [1])
  ('dense_1/kernel', [256, 1])
  ('h_encode_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias', [512])
  ('h_encode_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel', [428, 512])
  ('h_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias', [512])
  ('h_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel', [1152, 512])
  ('s_endode_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias', [512])
  ('s_endode_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel', [428, 512])
  ('s_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias', [512])
  ('s_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel', [1152, 512])
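A quick way to confirm the mismatch is to diff the two name sets. The sketch below is plain Python over the names and shapes pasted above, so it needs no TensorFlow; against a real checkpoint, the second dict could presumably be built from `tf.train.list_variables(path)` in TF 1.x, which returns `(name, shape)` pairs in the same form as the listing. The helper name `diff_names` is hypothetical.

```python
# Names and shapes copied from the two listings in this issue. The ":0" output
# suffix is dropped from the graph variables because checkpoint keys lack it.
CELL = "bidirectional_rnn/fw/basic_lstm_cell"

graph_shapes = {
    f"encode_rnn/birnn/{CELL}/kernel": [428, 512],
    f"encode_rnn/birnn/{CELL}/bias": [512],
    f"infer_rnn/birnn/{CELL}/kernel": [1152, 512],
    f"infer_rnn/birnn/{CELL}/bias": [512],
    "dense/kernel": [1024, 256],
    "dense/bias": [256],
    "dense_1/kernel": [256, 1],
    "dense_1/bias": [1],
}

ckpt_shapes = {
    "dense/bias": [256],
    "dense/kernel": [1024, 256],
    "dense_1/bias": [1],
    "dense_1/kernel": [256, 1],
    f"h_encode_rnn/{CELL}/bias": [512],
    f"h_encode_rnn/{CELL}/kernel": [428, 512],
    f"h_infer_rnn/{CELL}/bias": [512],
    f"h_infer_rnn/{CELL}/kernel": [1152, 512],
    f"s_endode_rnn/{CELL}/bias": [512],  # spelled "endode" in the checkpoint
    f"s_endode_rnn/{CELL}/kernel": [428, 512],
    f"s_infer_rnn/{CELL}/bias": [512],
    f"s_infer_rnn/{CELL}/kernel": [1152, 512],
}


def diff_names(graph, ckpt):
    """Return (keys the graph needs but the checkpoint lacks,
    checkpoint keys that no graph variable uses)."""
    missing = sorted(set(graph) - set(ckpt))
    unused = sorted(set(ckpt) - set(graph))
    return missing, unused


missing, unused = diff_names(graph_shapes, ckpt_shapes)
print("Missing from checkpoint:", *missing, sep="\n  ")
print("Unused checkpoint keys:", *unused, sep="\n  ")
```

Run on the listings above, this reports the four `encode_rnn/birnn/...` and `infer_rnn/birnn/...` keys as missing and the eight `h_`/`s_` keys as unused, while the `dense` and `dense_1` layers match. That is consistent with the `NotFoundError`, which names the first missing key.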

I am using the pretrained model in model/esim_0/sentence_retrieval_ensemble/model1.

Does anyone know how to fix this?

Martin36 commented 2 years ago

When I train the model from scratch and then load it, it seems to work.
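That training from scratch works suggests the current code saves and restores under its own variable names, and only the shipped checkpoint uses the older `h_`/`s_` scopes. One possible workaround, sketched under assumptions: each RNN variable in the graph has a same-shaped counterpart in the checkpoint once the `birnn/` level is dropped and an `h_` or `s_` prefix is added, and `tf.train.Saver` in TF 1.x accepts a dict mapping checkpoint keys to variables. The rewrite rule and the helpers `remap_to_checkpoint` and `build_var_map` are hypothetical, derived only from the two listings in this issue. Caveat: the checkpoint appears to hold two separate encoder/inference RNN pairs (`h_` and `s_`) while the current graph builds one shared pair, so restoring either branch ignores the other branch's weights and is unlikely to reproduce the published results; running the code revision that builds the `h_`/`s_` scopes, if one exists, is presumably the cleaner fix.

```python
import re

# Shapes copied from the listings in this issue (graph names without ":0").
CELL = "bidirectional_rnn/fw/basic_lstm_cell"
graph_shapes = {
    f"encode_rnn/birnn/{CELL}/kernel": [428, 512],
    f"encode_rnn/birnn/{CELL}/bias": [512],
    f"infer_rnn/birnn/{CELL}/kernel": [1152, 512],
    f"infer_rnn/birnn/{CELL}/bias": [512],
    "dense/kernel": [1024, 256],
    "dense/bias": [256],
    "dense_1/kernel": [256, 1],
    "dense_1/bias": [1],
}
ckpt_shapes = {
    "dense/bias": [256],
    "dense/kernel": [1024, 256],
    "dense_1/bias": [1],
    "dense_1/kernel": [256, 1],
    f"h_encode_rnn/{CELL}/bias": [512],
    f"h_encode_rnn/{CELL}/kernel": [428, 512],
    f"h_infer_rnn/{CELL}/bias": [512],
    f"h_infer_rnn/{CELL}/kernel": [1152, 512],
    f"s_endode_rnn/{CELL}/bias": [512],
    f"s_endode_rnn/{CELL}/kernel": [428, 512],
    f"s_infer_rnn/{CELL}/bias": [512],
    f"s_infer_rnn/{CELL}/kernel": [1152, 512],
}


def remap_to_checkpoint(graph_name, branch):
    """Hypothetical rule: '<scope>/birnn/<cell>' -> '<branch>_<scope>/<cell>'.

    branch is "h" or "s", the two prefixes seen in the checkpoint.
    """
    m = re.match(r"(encode_rnn|infer_rnn)/birnn/(.+)", graph_name)
    if m is None:
        return graph_name  # dense/ and dense_1/ already match
    scope, rest = m.groups()
    if branch == "s" and scope == "encode_rnn":
        scope = "endode_rnn"  # the checkpoint spells this scope 's_endode_rnn'
    return f"{branch}_{scope}/{rest}"


def build_var_map(graph_shapes, ckpt_shapes, branch):
    """Map checkpoint key -> graph variable name, checking that shapes agree."""
    var_map = {}
    for name, shape in graph_shapes.items():
        key = remap_to_checkpoint(name, branch)
        if ckpt_shapes.get(key) != shape:
            raise KeyError(f"{name} -> {key}: missing or shape mismatch")
        var_map[key] = name
    return var_map


var_map = build_var_map(graph_shapes, ckpt_shapes, branch="h")
for key, name in sorted(var_map.items()):
    print(f"{key}  <-  {name}")
```

In the TF 1.x graph, such a map could then be handed to the Saver roughly as `tf.train.Saver({k: by_name[v] for k, v in var_map.items()})`, where `by_name` is an assumed dict from `var.op.name` to each trainable variable. Checking shapes before building the Saver turns a silent mis-mapping into an explicit error.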