merantix / imitation-learning

Autonomous driving: Tensorflow implementation of the paper "End-to-end Driving via Conditional Imitation Learning"
https://medium.com/merantix/journey-from-academic-paper-to-industry-usage-cf57fe598f31
MIT License

Add support for other ways of training than using tf.Estimator #4

Open: roberttorfason opened this issue 5 years ago


Currently, training only works through the tf.Estimator framework. Some users might prefer a lower-level approach: building the graph themselves and calling the training op directly with sess.run. A starting point for that version might look like this:

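import tensorflow as tf  # TF 1.x API (tf.Session, tf.estimator.ModeKeys)
# input_fn and trainer are this repository's modules


def main(tfrecord_path, bs, sbs, num_steps):
    # tf.Estimator creates the global step before calling model_fn; do the same here
    tf.train.get_or_create_global_step()
    features, labels = input_fn.train_input_fn(tfrecord_path, batch_size=bs, shuffle_buffer_size=sbs)()
    model = trainer.model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)  # returns an EstimatorSpec
    train_op = model.train_op

    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        # each sess.run(train_op) consumes one batch: one training step, not one epoch
        for _ in range(num_steps):
            sess.run(train_op)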
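In a loop like the one above, each sess.run(train_op) call performs a single optimizer step on one batch, so iterating over a number of epochs requires converting epochs into steps. Below is a minimal, framework-free sketch of that conversion. It assumes the number of examples in the TFRecord file is known in advance (the snippet above does not expose it). The names steps_per_epoch, run_training, num_examples and train_step are hypothetical and are not part of the repository. The drop_remainder flag mirrors the argument of the same name on tf.data.Dataset.batch, which decides whether a final partial batch is discarded.

```python
def steps_per_epoch(num_examples: int, batch_size: int, drop_remainder: bool = False) -> int:
    """Number of train_op calls (batches) needed to see every example once."""
    if drop_remainder:
        # the final partial batch is discarded, so round down
        return num_examples // batch_size
    # the final partial batch still counts as one step, so round up (integer ceil)
    return (num_examples + batch_size - 1) // batch_size


def run_training(train_step, num_examples: int, batch_size: int, num_epochs: int) -> int:
    """Call train_step once per batch for num_epochs epochs and return the step count.

    Hypothetical helper: train_step stands in for `lambda: sess.run(train_op)`.
    """
    total_steps = steps_per_epoch(num_examples, batch_size) * num_epochs
    for _ in range(total_steps):
        train_step()
    return total_steps


if __name__ == "__main__":
    calls = []
    # 1000 examples with batch size 32 give 32 steps per epoch (the last batch holds 8)
    print(run_training(lambda: calls.append(1), num_examples=1000, batch_size=32, num_epochs=2))
```

If the input pipeline repeats the data indefinitely, as training pipelines often do, nothing signals the end of an epoch, and an explicit step count like this is the only stopping criterion.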