Status: Open — opened by HH-66 4 months ago
A temporary fix based on deeprec2302: https://github.com/candyzone/DeepRec/commit/da651f08be0fb0ded8b22313f8ec48de4bccaca2
This issue has been fixed in release deeprec2402.
使用了 partitioner 之后，问题仍然存在，用下面的代码可以复现（2302 版本）。
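测试代码说明：
测试模型经过专门设计，具有以下特点（具体实现参见 model_fn 函数）：embedding 初始值为 1，学习率为 0.01，label 全为 0，loss 为所查 embedding 的均值。因此每训练一步，当前 key 的权重减少 0.01；评测时的 loss 即等于被评测 key 的 embedding 权重。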
训练的 key 依次为：key 0 训练 20 步，key 1 训练 40 步，key 2 训练 30 步，key 3 训练 10 步（训练结束后各 key 的权重应约为 0.8、0.6、0.7、0.9）。
训练命令：./bare_minimum.py train
评测命令：./bare_minimum.py eval --value 1  # 评测 key 为 1 的 embedding 权重
修改 checkpoint_dir/checkpoint 文件，可以分别评测增量 checkpoint 和全量 checkpoint。
代码中的其它内容主要是为了让 DeepRec 在使用 Estimator API 时也能正确生成和加载增量 checkpoint。
#!/usr/bin/env python3
import argparse
import functools
import os.path
import time
import tensorflow as tf
import numpy
# Incremental-checkpoint intervals that _create_session injects into
# MonitoredTrainingSession; patch_incr_ckpt() overwrites them.  None means
# "do not override the caller's value".  They must be bound here: a `global`
# statement at module scope binds nothing, so reading the names before
# patch_incr_ckpt() ran would raise NameError.
_incr_ckpt_secs = None
_incr_ckpt_steps = None
class DelayHook(tf.train.SessionRunHook):
    """Session hook that sleeps 20 ms after every ``session.run`` call.

    Slowing the short training run down presumably gives the 1-second
    incremental-checkpoint timer a chance to fire between steps.
    """

    def after_run(self, run_context, run_values):
        del run_context, run_values  # not needed; the hook only delays
        time.sleep(0.02)
def get_ev_option():
    """Return the ``EmbeddingVariableOption`` used for the embedding table.

    Newly seen keys are initialised to the constant 1; no feature filter and
    no eviction policy are configured.
    """
    return tf.EmbeddingVariableOption(
        init_option=tf.InitializerOption(
            initializer=tf.constant_initializer(1)),
        filter_option=None,
        evict_option=None,
    )
def model_fn(features, labels, mode, params):
    """Estimator ``model_fn`` for the incremental-checkpoint reproduction.

    The model is a single EmbeddingVariable with ``embedding_dim=1``.  The loss
    is the batch mean of (looked-up embedding value - label); the input
    functions in this script use all-zero labels, so the loss equals the mean
    embedding value of the looked-up keys.

    Args:
        features: dict whose ``'x'`` entry holds the int64 keys to look up.
        labels: float32 label tensor.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: unused.

    Returns:
        ``tf.estimator.EstimatorSpec``.  In EVAL mode it carries only the loss
        and the scaffold; otherwise it also carries the SGD train op, a
        chief-only ``CheckpointSaverHook`` and the logging/delay hooks.
    """
    id_ = features['x']
    # Created through a single-shard fixed_size_partitioner: the issue also
    # reproduces when the EV is partitioned.
    weights = tf.get_embedding_variable(
        name='embedding_table', embedding_dim=1,
        value_dtype=tf.float32,
        ev_option=get_ev_option(), key_dtype=tf.int64,
        partitioner=tf.fixed_size_partitioner(num_shards=1))
    x = tf.nn.embedding_lookup(weights, id_)
    # Average over the embedding dimension, then over the batch.
    y = tf.reduce_mean(x, axis=1)
    loss = tf.reduce_mean(y - labels)
    # Sharded saver with DeepRec's incremental save/restore enabled; the same
    # flag is passed to the Scaffold.
    # NOTE(review): the Saver collects its variables when it is built, i.e.
    # before optimizer.minimize() below.  Plain SGD creates no slot variables,
    # but an optimizer that does would leave them out of this saver.
    saver = tf.train.Saver(
        sharded=True, incremental_save_restore=True,
        save_relative_paths=True)
    scaffold = tf.train.Scaffold(saver=saver, incremental_save_restore=True)
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss, scaffold=scaffold)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    # Full checkpoint every 50 steps; incremental_save_secs is a DeepRec
    # extension, presumably an incremental checkpoint about every second.
    # NOTE(review): checkpoint_dir is hard-coded and must match the
    # Estimator's model_dir.
    saver_hook = tf.train.CheckpointSaverHook(
        incremental_save_secs=1, checkpoint_dir='checkpoint_dir',
        save_steps=50, scaffold=scaffold, listeners=[])
    # Log loss and global step on every iteration.
    log_hook = tf.train.LoggingTensorHook(
        {'loss': loss, 'step': tf.train.get_or_create_global_step()},
        every_n_iter=1)
    minimize = optimizer.minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    # The saver hook runs on the chief only; DelayHook slows training so
    # several incremental checkpoints can be written.
    return tf.estimator.EstimatorSpec(
        mode, loss=loss, train_op=minimize,
        training_chief_hooks=[saver_hook], scaffold=scaffold,
        training_hooks=[log_hook, DelayHook()])
def train_input_fn():
    """Training input: one unbatched, single-key example per step.

    Keys arrive in a fixed order — key 0 for 20 steps, key 1 for 40, key 2
    for 30 and key 3 for 10 — each paired with a zero float32 label.
    """
    schedule = ((0, 20), (1, 40), (2, 30), (3, 10))

    def generator():
        for key, repeats in schedule:
            for _ in range(repeats):
                yield ({'x': numpy.array([key], dtype=numpy.int64)},
                       numpy.zeros([1], dtype=numpy.float32))

    return tf.data.Dataset.from_generator(
        generator,
        output_types=({'x': tf.int64}, tf.float32),
        output_shapes=({'x': tf.TensorShape([None])}, tf.TensorShape([None])))
def eval_input_fn(value):
    """Evaluation input: 10 examples, each looking up key ``value`` twice.

    Labels are zeros, so with this file's ``model_fn`` the evaluation loss
    equals the embedding weight stored for ``value``.
    """
    def generator():
        keys = [value, value]
        for _ in range(10):
            yield ({'x': numpy.array(keys, dtype=numpy.int64)},
                   numpy.zeros([2], dtype=numpy.float32))

    return tf.data.Dataset.from_generator(
        generator,
        output_types=({'x': tf.int64}, tf.float32),
        output_shapes=({'x': tf.TensorShape([None])}, tf.TensorShape([None])))
def _patch_session_creator(checkpoint_dir):
    """Install the estimator patches needed to restore incremental checkpoints.

    The steps are:

    * Replace ``monitored_session.ChiefSessionCreator`` with
      ``_session_creator`` bound to ``checkpoint_dir``.  The real class is
      kept as ``ChiefSessionCreator__``.
    * Patch the train-and-evaluate evaluator's ``evaluate_and_export``.
    * Patch ``tf.train.SessionManager.recover_session``.

    Args:
        checkpoint_dir: directory the patched chief session creator restores
            from.
    """
    tf.logging.info('Patching monitored_session.ChiefSessionCreator')
    from tensorflow.python.training import monitored_session
    # Save the real class only once.  On a repeated call this attribute would
    # otherwise capture our own partial, and _session_creator would then call
    # itself until RecursionError.
    if not hasattr(monitored_session, 'ChiefSessionCreator__'):
        monitored_session.ChiefSessionCreator__ = monitored_session.ChiefSessionCreator
    monitored_session.ChiefSessionCreator = functools.partial(
        _session_creator, checkpoint_dir=checkpoint_dir)
    _patch_evaluate_and_export()
    _patch_evaluate_recover_session()
def _create_session(*args, **kwargs):
    """Stand-in for ``MonitoredTrainingSession`` that adds incremental-checkpoint intervals.

    The intervals come from the module-level values set by
    ``patch_incr_ckpt``.  ``save_incremental_checkpoint_secs`` and
    ``save_incremental_checkpoint_steps`` are filled in only when the caller
    omitted them or passed None; explicit caller values win.  The call is
    forwarded to the original function saved as
    ``tf.train.MonitoredTrainingSession__``.
    """
    overrides = (
        ('save_incremental_checkpoint_secs', _incr_ckpt_secs),
        ('save_incremental_checkpoint_steps', _incr_ckpt_steps),
    )
    for option, fallback in overrides:
        if fallback is not None and kwargs.get(option) is None:
            kwargs[option] = fallback
    tf.logging.info("Creating MonitoredTrainingSession, %s, %s", args, kwargs)
    return tf.train.MonitoredTrainingSession__(*args, **kwargs)
def _session_creator(**kwargs):
    """Build the original ``ChiefSessionCreator`` with any explicit checkpoint file dropped.

    ``checkpoint_filename_with_path`` is always forced to None, so the
    creator restores from its ``checkpoint_dir`` rather than from one
    specific checkpoint file.  This is presumably what lets the
    incremental-checkpoint restore path run; confirm against DeepRec.
    """
    from tensorflow.python.training import monitored_session
    creator_kwargs = dict(kwargs, checkpoint_filename_with_path=None)
    tf.logging.info('Creating ChiefSessionCreator: %s', creator_kwargs)
    return monitored_session.ChiefSessionCreator__(**creator_kwargs)
def patch_incr_ckpt(secs=0, steps=0):
    """Make every ``MonitoredTrainingSession`` write incremental checkpoints.

    Stores the requested intervals in the module globals read by
    ``_create_session``.  It then replaces ``MonitoredTrainingSession`` with
    ``_create_session`` in both ``tensorflow.python.training.training`` and
    ``tf.train``.  Each module keeps its original function as
    ``MonitoredTrainingSession__``.

    Safe to call more than once: later calls only update the intervals.

    Args:
        secs: incremental checkpoint interval in seconds; values <= 0 leave
            the caller's setting untouched.
        steps: incremental checkpoint interval in steps; values <= 0 leave
            the caller's setting untouched.
    """
    global _incr_ckpt_secs
    global _incr_ckpt_steps
    _incr_ckpt_secs = secs if secs > 0 else None
    _incr_ckpt_steps = steps if steps > 0 else None
    tf.logging.info("Patching MonitoredTrainingSession.")
    from tensorflow.python.training import training
    # Save each original only once.  Otherwise a second call (or tf.train
    # aliasing the same module object) would store _create_session as the
    # "original", and _create_session would then recurse forever.
    if not hasattr(training, 'MonitoredTrainingSession__'):
        training.MonitoredTrainingSession__ = training.MonitoredTrainingSession
    training.MonitoredTrainingSession = _create_session
    if not hasattr(tf.train, 'MonitoredTrainingSession__'):
        tf.train.MonitoredTrainingSession__ = tf.train.MonitoredTrainingSession
    tf.train.MonitoredTrainingSession = _create_session
def _patch_evaluate_and_export():
    """Patch ``_TrainingExecutor._Evaluator.evaluate_and_export`` for incremental checkpoints.

    The replacement looks for the newest checkpoint among the latest full
    checkpoint and the latest checkpoint under
    ``<model_dir>/.incremental_checkpoint``.  It skips the evaluation pass
    only when that newest step equals the step evaluated last time.  The
    original method is kept as ``evaluate_and_export__``.
    """
    from tensorflow.python.training import checkpoint_management
    from tensorflow_estimator.python.estimator.training import _EvalResult, _EvalStatus, _TrainingExecutor
    from tensorflow.python.framework import ops
    from tensorflow.python.eager import context

    def _ckpt_version(ckpt_path):
        """Return the integer step suffix of ``ckpt_path`` (``...-<step>``), or None if no path."""
        return int(ckpt_path.split('-')[-1]) if ckpt_path else None

    def evaluate_and_export(self):
        tf.logging.info('custom evaluate_and_export')
        latest_ckpt_path = self._estimator.latest_checkpoint()
        if not latest_ckpt_path:
            self._log_err_msg('Estimator is not trained yet. Will start an '
                              'evaluation when a checkpoint is ready.')
            return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT), []

        # Incremental checkpoints live in <model_dir>/.incremental_checkpoint.
        with context.graph_mode():
            incremental_dir = os.path.join(self._estimator.model_dir, '.incremental_checkpoint')
            incremental_ckpt = checkpoint_management.latest_checkpoint(incremental_dir)

        base_version = _ckpt_version(latest_ckpt_path)
        incremental_version = _ckpt_version(incremental_ckpt)
        previous_version = _ckpt_version(self._previous_ckpt_path)
        tf.logging.info(f'now version: {base_version} {incremental_version} <- {previous_version}')

        # The newest checkpoint of either kind identifies the model state.
        # Comparing it against the last evaluated step skips an unchanged full
        # checkpoint even when no incremental checkpoint exists.  Using
        # `is not None` ensures step 0 is not mistaken for "nothing evaluated
        # yet".
        if incremental_version is not None and incremental_version >= base_version:
            current_ckpt_path, current_version = incremental_ckpt, incremental_version
        else:
            current_ckpt_path, current_version = latest_ckpt_path, base_version
        if previous_version is not None and current_version == previous_version:
            self._log_err_msg(
                'No new checkpoint ready for evaluation. Skip the current '
                'evaluation pass as evaluation results are expected to be same '
                'for the same checkpoint.')
            return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT), []

        metrics = self._estimator.evaluate(
            input_fn=self._eval_spec.input_fn,
            steps=self._eval_spec.steps,
            name=self._eval_spec.name,
            checkpoint_path=latest_ckpt_path,
            hooks=self._eval_spec.hooks)

        # _EvalResult validates the metrics.
        eval_result = _EvalResult(
            status=_EvalStatus.EVALUATED,
            metrics=metrics,
            checkpoint_path=latest_ckpt_path)

        is_the_final_export = (
            eval_result.metrics[ops.GraphKeys.GLOBAL_STEP] >=
            self._max_training_steps if self._max_training_steps else False)
        export_results = self._export_eval_result(eval_result,
                                                  is_the_final_export)

        if is_the_final_export:
            tf.logging.debug('Calling exporter with the `is_the_final_export=True`.')
            self._is_final_export_triggered = True

        self._last_warning_time = 0
        # Remember the checkpoint whose step was just evaluated.
        self._previous_ckpt_path = current_ckpt_path
        return eval_result, export_results

    _TrainingExecutor._Evaluator.evaluate_and_export__ = _TrainingExecutor._Evaluator.evaluate_and_export
    _TrainingExecutor._Evaluator.evaluate_and_export = evaluate_and_export
def _patch_evaluate_recover_session():
    """Replace ``tf.train.SessionManager.recover_session`` with an incremental-aware copy.

    The replacement is installed on the class; the previous method is kept as
    ``recover_session__``.
    """
    def recover_session(self,
                        master,
                        saver=None,
                        checkpoint_dir=None,
                        checkpoint_filename_with_path=None,
                        wait_for_checkpoint=False,
                        max_wait_secs=7200,
                        config=None):
        """Restore a session from a checkpoint, also passing an incremental saver.

        Follows the structure of ``SessionManager.recover_session``, but builds
        an incremental saver and hands it to ``self._restore_checkpoint``.

        NOTE(review): this relies on DeepRec internals:
        ``incremental_saver._get_incremental_saver``,
        ``self._incremental_save_restore`` and a
        ``_restore_checkpoint(master, saver, incr_saver, ...)`` signature.
        Verify them against the installed DeepRec version.

        Returns:
            ``(sess, loaded)``.  ``loaded`` is True only if a checkpoint was
            restored and both the local-init op and the readiness check
            succeeded; otherwise it is False.
        """
        from tensorflow.python.training import incremental_saver
        # NOTE(review): the incremental saver is derived from self._saver, not
        # from the `saver` argument of this call.
        incr_saver = incremental_saver._get_incremental_saver(self._incremental_save_restore, self._saver)
        tf.logging.info("custom recover_session")
        sess, is_loaded_from_checkpoint = self._restore_checkpoint(
            master,
            saver,
            incr_saver,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename_with_path=checkpoint_filename_with_path,
            wait_for_checkpoint=wait_for_checkpoint,
            max_wait_secs=max_wait_secs,
            config=config)
        # Always try to run local_init_op
        local_init_success, msg = self._try_run_local_init_op(sess)
        if not is_loaded_from_checkpoint:
            # Do not need to run checks for readiness
            return sess, False
        # Only used in the log messages below.
        restoring_file = checkpoint_dir or checkpoint_filename_with_path
        if not local_init_success:
            tf.logging.info(
                "Restoring model from %s did not make model ready for local init:"
                " %s", restoring_file, msg)
            return sess, False
        is_ready, msg = self._model_ready(sess)
        if not is_ready:
            tf.logging.info("Restoring model from %s did not make model ready: %s",
                            restoring_file, msg)
            return sess, False
        tf.logging.info("Restored model from %s", restoring_file)
        return sess, is_loaded_from_checkpoint

    # Keep the previous implementation reachable, then install the copy.
    tf.train.SessionManager.recover_session__ = tf.train.SessionManager.recover_session
    tf.train.SessionManager.recover_session = recover_session
def parse_cmdline(argv=None):
    """Parse the command line of this reproduction script.

    Args:
        argv: argument list to parse, without the program name.  When None
            (the default), ``sys.argv[1:]`` is used, so existing callers are
            unaffected.

    Returns:
        ``argparse.Namespace`` with ``mode`` ('train' or 'eval') and
        ``value`` (int embedding key to evaluate, default 0).

    Raises:
        SystemExit: on invalid arguments or ``--help`` (argparse behaviour).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=('train', 'eval'),
                        help='train the model, or evaluate the weight of --value')
    parser.add_argument('--value', type=int, default=0,
                        help='embedding key to evaluate in eval mode (default: 0)')
    return parser.parse_args(argv)
def main():
    """Install the incremental-checkpoint patches, then train or evaluate."""
    args = parse_cmdline()
    tf.logging.set_verbosity(tf.logging.INFO)

    # The patches replace functions looked up at session-creation time, so
    # they go in before any training or evaluation starts.
    # NOTE(review): model_dir must match the checkpoint_dir hard-coded in
    # model_fn's CheckpointSaverHook.
    model_dir = 'checkpoint_dir'
    patch_incr_ckpt(secs=1)
    _patch_session_creator(model_dir)

    eval_input = functools.partial(eval_input_fn, value=args.value)
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input)
    run_config = tf.estimator.RunConfig(
        model_dir=model_dir,
        tf_random_seed=2020,
        save_summary_steps=1,
        save_checkpoints_steps=50,
        keep_checkpoint_max=20,
        experimental_max_worker_delay_secs=2000)
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

    if args.mode != 'eval':
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    else:
        estimator.evaluate(eval_input)
# Script entry point; importing this module has no side effects beyond
# definitions.
if __name__ == "__main__":
    main()
System information
Describe the current behavior: 在 restore 时加载 incremental checkpoint，其中的 EV（EmbeddingVariable）变量不能正确覆盖 base checkpoint 里对应的 EV 变量。
Describe the expected behavior: 正确加载 incremental checkpoint 中的 EV，并覆盖 base checkpoint 中对应的变量。
Code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate the problem.
Other info / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.