In addition to the two trainer files, mains/test_train.py has been added for testing purposes.
The BaseTrainer class is an abstract definition of how a trainer behaves. It is initialized with the model and the session and provides a train() method that runs the training loop and logs training and validation stats via tf.logging. Training and validation can be customized through flags (see below).
The ResNetTrainer class is a concrete implementation. It implements the train_epoch and val_epoch methods that BaseTrainer calls, and it uses the AdamOptimizer to update the model weights.
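For readers who have not opened the files yet, the split between the two classes can be sketched roughly as follows. This is a minimal illustration and not the code in this PR: the constructor signature, the loss return values, and the ToyTrainer subclass are assumptions, and Python's standard logging module stands in for tf.logging so the example runs without TensorFlow.

```python
import abc
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("trainer")


class BaseTrainer(abc.ABC):
    """Holds the model and session and drives the epoch loop."""

    def __init__(self, model, session, num_epochs):
        self.model = model
        self.session = session
        # Assumption: passed in directly here; the PR takes it from FLAGS.
        self.num_epochs = num_epochs

    def train(self):
        """Run all epochs and log train/validation stats after each one."""
        history = []
        for epoch in range(1, self.num_epochs + 1):
            train_loss = self.train_epoch()
            val_loss = self.val_epoch()
            log.info("epoch %d: train_loss=%.4f val_loss=%.4f",
                     epoch, train_loss, val_loss)
            history.append((train_loss, val_loss))
        return history

    @abc.abstractmethod
    def train_epoch(self):
        """Run one pass over the training data and return the mean loss."""

    @abc.abstractmethod
    def val_epoch(self):
        """Run one pass over the validation data and return the mean loss."""


class ToyTrainer(BaseTrainer):
    """Hypothetical subclass; a real one would run optimizer steps."""

    def __init__(self, model, session, num_epochs):
        super().__init__(model, session, num_epochs)
        self._loss = 1.0

    def train_epoch(self):
        self._loss *= 0.5  # pretend each epoch halves the loss
        return self._loss

    def val_epoch(self):
        return self._loss + 0.1


history = ToyTrainer(model=None, session=None, num_epochs=3).train()
```

The base class owns the loop and the logging, so a concrete trainer only has to say what one training epoch and one validation epoch do.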
The following FLAGS are used by the trainer (declared in base_trainer.py):
learning_rate
num_epochs
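As a rough illustration of how these two options are supplied on the command line, the sketch below uses the standard library's argparse as a stand-in for the FLAGS mechanism, so it runs without TensorFlow. The default values and help strings are assumptions, not the ones declared in base_trainer.py.

```python
import argparse


def parse_trainer_flags(argv=None):
    """Parse the two trainer options; defaults here are hypothetical."""
    parser = argparse.ArgumentParser(description="Trainer options (sketch)")
    parser.add_argument("--learning_rate", type=float, default=1e-3,
                        help="Step size passed to the optimizer.")
    parser.add_argument("--num_epochs", type=int, default=10,
                        help="Number of passes over the training data.")
    return parser.parse_args(argv)


# Equivalent to: python some_script.py --learning_rate 0.01 --num_epochs 5
flags = parse_trainer_flags(["--learning_rate", "0.01", "--num_epochs", "5"])
print(flags.learning_rate, flags.num_epochs)
```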
The train/val batch sizes are configured through the batch_size parameter of the TinyImageNetPipeline initializer.
When required in the future, the following features may be added to the trainer:
learning rate decay
early stopping mechanisms
instead of tf.logging, use an abstract Logger class that supports multiple formats (TensorBoard, comet.ml, console, ...)
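To make the last item more concrete, an abstract Logger could look something like the sketch below. Nothing here exists in this PR: the Logger interface, the log_metrics method, and the three backends are hypothetical names chosen for illustration. A TensorBoard or comet.ml backend would be one more subclass.

```python
import abc


class Logger(abc.ABC):
    """Hypothetical interface: one call per reported epoch."""

    @abc.abstractmethod
    def log_metrics(self, epoch, metrics):
        """Record a dict of metric name -> value for the given epoch."""


class ConsoleLogger(Logger):
    """Prints metrics as a single line per epoch."""

    def log_metrics(self, epoch, metrics):
        parts = ", ".join(
            f"{name}={value:.4f}" for name, value in sorted(metrics.items())
        )
        print(f"epoch {epoch}: {parts}")


class MemoryLogger(Logger):
    """Keeps records in a list, which is handy in tests."""

    def __init__(self):
        self.records = []

    def log_metrics(self, epoch, metrics):
        self.records.append((epoch, dict(metrics)))


class MultiLogger(Logger):
    """Forwards every call to several backends."""

    def __init__(self, *loggers):
        self.loggers = loggers

    def log_metrics(self, epoch, metrics):
        for logger in self.loggers:
            logger.log_metrics(epoch, metrics)


memory = MemoryLogger()
logger = MultiLogger(ConsoleLogger(), memory)
logger.log_metrics(1, {"train_loss": 0.5, "val_loss": 0.6})
```

With such an interface the trainer would only depend on Logger, and the choice of output format would be made where the trainer is constructed.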
This PR contains essentially two new files:
/trainer/base_trainer.py
/trainer/resnet_trainer.py
Merging this PR resolves #49.