Currently, training only works through the `tf.Estimator` framework. Some users might prefer the standard `sess.run` approach of invoking the training op directly, which gives lower-level control over training. A starting point for that version might look like this:
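```python
import tensorflow as tf  # TF 1.x graph-mode API
import input_fn  # project module (import path assumed)
import trainer   # project module (import path assumed)

def main(tfrecord_path, bs, sbs, num_steps):
    # train_input_fn returns an input function; calling it yields the (features, labels) tensors.
    features, labels = input_fn.train_input_fn(tfrecord_path, batch_size=bs, shuffle_buffer_size=sbs)()
    # model_fn returns a tf.estimator.EstimatorSpec; only its train_op is needed here.
    model = trainer.model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)
    train_op = model.train_op
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        # Each sess.run(train_op) processes one batch, i.e. one training step, not one epoch.
        for _ in range(num_steps):
            sess.run(train_op)
```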
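To try the same pattern without the project's `input_fn` and `trainer` modules, here is a minimal self-contained sketch. It assumes the `tf.compat.v1` API (so it also runs under TF 2.x in graph mode) and swaps in a toy linear-regression model fed from in-memory data; the function name `train_toy_model` and all hyperparameters are hypothetical illustrations, not part of this project.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, so Session / sess.run work under TF 2.x


def train_toy_model(num_steps=200, batch_size=32, seed=0):
    """Hypothetical toy: fit y = 3x + 2 with a hand-written sess.run loop."""
    rng = np.random.RandomState(seed)
    xs = rng.uniform(-1.0, 1.0, size=(1024, 1)).astype(np.float32)
    ys = 3.0 * xs + 2.0

    # Build everything in a fresh graph so repeated calls don't clash on variable names.
    with tf.Graph().as_default():
        # Input pipeline: a shuffled, endlessly repeating stream of batches.
        dataset = (tf.data.Dataset.from_tensor_slices((xs, ys))
                   .shuffle(1024, seed=seed).repeat().batch(batch_size))
        features, labels = tf1.data.make_one_shot_iterator(dataset).get_next()

        # Model: one linear layer with a mean-squared-error loss.
        w = tf1.get_variable("w", shape=[1, 1], initializer=tf1.zeros_initializer())
        b = tf1.get_variable("b", shape=[1], initializer=tf1.zeros_initializer())
        loss = tf.reduce_mean(tf.square(tf.matmul(features, w) + b - labels))
        global_step = tf1.train.get_or_create_global_step()
        train_op = tf1.train.GradientDescentOptimizer(0.1).minimize(
            loss, global_step=global_step)

        with tf1.Session() as sess:
            sess.run([tf1.global_variables_initializer(),
                      tf1.local_variables_initializer()])
            first_loss = sess.run(loss)
            last_loss = first_loss
            # One sess.run call = one optimizer step on one batch.
            for _ in range(num_steps):
                _, last_loss = sess.run([train_op, loss])
            return first_loss, last_loss, sess.run(w)[0, 0], sess.run(b)[0]
```

Calling `train_toy_model()` should drive the loss down and `w`, `b` toward 3 and 2. Because the dataset calls `.repeat()`, the loop is bounded by a step count; with a non-repeating dataset you would instead catch `tf.errors.OutOfRangeError` to detect the end of the input.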