In the TensorFlow 2 code path, a bug prevents loading a checkpoint (here, the one passed via `warm_start_from`).
The cause is visible directly in the code. In `tensor2tensor/utils/contrib.py`, when `tensorflow.contrib` is unavailable, `framework()` returns a `DummyModule`, which does not define `load_checkpoint`. Then `tensor2tensor/utils/t2t_model.py` tries to load the checkpoint via:
```python
reader = contrib.framework().load_checkpoint(ckpt_dir)
variable_map = {}
for var in contrib.framework().get_trainable_variables():
    ...
```
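The failure mode can be sketched in plain Python. The `DummyModule` and `framework()` below are simplified, hypothetical stand-ins written for illustration, not the actual tensor2tensor source. The only point is that calling a method on a placeholder object that never received `load_checkpoint` raises the same `AttributeError` seen in the traceback.

```python
class DummyModule(object):
    """Hypothetical placeholder used when tensorflow.contrib is missing."""

    def __init__(self, **kwargs):
        # Only the attributes passed in exist on the placeholder.
        for name, value in kwargs.items():
            setattr(self, name, value)


def framework():
    """Hypothetical stand-in for contrib.framework() on the TF2 path."""
    # Neither load_checkpoint nor get_trainable_variables is registered.
    return DummyModule(arg_scope=None)


def initialize_from_ckpt(ckpt_dir):
    """Mirrors the failing line in t2t_model.initialize_from_ckpt."""
    return framework().load_checkpoint(ckpt_dir)


def reproduce():
    """Return the error message raised by the call above, or None."""
    try:
        initialize_from_ckpt("/tmp/model.ckpt")
    except AttributeError as err:
        return str(err)
    return None


print(reproduce())  # 'DummyModule' object has no attribute 'load_checkpoint'
```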
I will open a pull request with my personal solution, and I am open to changing it to best fit the project.
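One possible direction for a fix, sketched here as an assumption and not as the change in the pull request, is to fall back to TensorFlow core APIs whenever the `contrib` placeholder lacks an attribute. The helper `resolve` and the stand-in loader below are hypothetical. In `t2t_model.py` the real fallbacks might be `tf.train.load_checkpoint` and `tf.compat.v1.trainable_variables`, both of which exist in TensorFlow 2.x; a stand-in is used here so the sketch runs without TensorFlow.

```python
from types import SimpleNamespace


def resolve(framework_module, name, fallback):
    """Return framework_module.<name> if present and not None, else fallback.

    Hypothetical helper: tensor2tensor does not ship this function.
    """
    attr = getattr(framework_module, name, None)
    return attr if attr is not None else fallback


def fake_load_checkpoint(ckpt_dir):
    # Stand-in for a TF-core loader such as tf.train.load_checkpoint.
    return "reader for " + ckpt_dir


# A placeholder without load_checkpoint, as on the TF2 code path.
dummy = SimpleNamespace(arg_scope=None)
loader = resolve(dummy, "load_checkpoint", fake_load_checkpoint)
reader = loader("/tmp/model.ckpt")
print(reader)  # reader for /tmp/model.ckpt
```

Keeping the lookup in one helper leaves the TF1 path using `contrib` unchanged while the TF2 path picks up core equivalents. Whether `tf.compat.v1.trainable_variables` matches `get_trainable_variables` closely enough at this call site would still need checking.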
Full log and traceback:

```
INFO:tensorflow:Checkpoint dir: /homes/ace01/forks/magenta/unconditional_model_16.ckpt
I0526 11:08:49.628702 140498416027456 t2t_model.py:2341] Checkpoint dir: /homes/ace01/forks/magenta/unconditional_model_16.ckpt
Traceback (most recent call last):
  File "/homes/ace01/forks/magenta/venv/bin/t2t_trainer", line 11, in <module>
    load_entry_point('magenta', 'console_scripts', 't2t_trainer')()
  File "/homes/ace01/forks/magenta/magenta/tensor2tensor/t2t_trainer.py", line 32, in console_entry_point
    tf.app.run(main)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow/python/platform/app.py", line 36, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/homes/ace01/forks/magenta/magenta/tensor2tensor/t2t_trainer.py", line 26, in main
    t2t_trainer.main(argv)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/bin/t2t_trainer.py", line 419, in main
    execute_schedule(exp)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/bin/t2t_trainer.py", line 372, in execute_schedule
    getattr(exp, FLAGS.schedule)()
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/trainer_lib.py", line 469, in continuous_train_and_eval
    self._eval_spec)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/training.py", line 504, in train_and_evaluate
    return executor.run()
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/training.py", line 645, in run
    return self.run_local()
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/training.py", line 746, in run_local
    saving_listeners=saving_listeners)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 360, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1186, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1215, in _train_model_default
    self.config)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1174, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 1422, in wrapping_model_fn
    use_tpu=use_tpu)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 1549, in estimator_model_fn
    loss, num_async_replicas=num_async_replicas, use_tpu=use_tpu)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 1592, in estimator_spec_train
    self.initialize_from_ckpt(self._hparams.warm_start_from)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 1552, in initialize_from_ckpt
    return initialize_from_ckpt(ckpt_dir=ckpt_dir, hparams=self._hparams)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 2342, in initialize_from_ckpt
    reader = contrib.framework().load_checkpoint(ckpt_dir)
AttributeError: 'DummyModule' object has no attribute 'load_checkpoint'
```
Environment information
Reproduction notes
This is not a minimal reproduction; it is simply how I encountered the bug while fine-tuning a musical note sequence model from Magenta.
Because the cause is evident from the code, I have not prepared a minimal reproduction, but I can provide one if needed.
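To confirm which code path a given environment takes, a quick check is whether `tensorflow.contrib` can be imported at all; the module was removed in TensorFlow 2.0. This is a generic sketch and does not replicate how `contrib.py` itself makes the decision.

```python
import importlib


def contrib_available():
    """Return True if tensorflow.contrib can be imported in this environment."""
    try:
        # ModuleNotFoundError (no TF, or TF 2.x) is a subclass of ImportError.
        importlib.import_module("tensorflow.contrib")
    except ImportError:
        return False
    return True


if __name__ == "__main__":
    print("tensorflow.contrib available:", contrib_available())
```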
Error logs