Simsso / NIPS-2018-Adversarial-Vision-Challenge

Code, documents, and deployment configuration files related to our participation in the 2018 NIPS Adversarial Vision Challenge "Robust Model Track"
MIT License

Resnet Trainer Classes #55

Closed FlorianPfisterer closed 6 years ago

FlorianPfisterer commented 6 years ago

This PR essentially contains two new files:

For testing purposes, mains/test_train.py has also been added.

The BaseTrainer class contains an abstract definition of how a trainer behaves. It is initialized with the model and the TensorFlow session, and it provides a train() method that automatically runs the training and logs training and validation stats via tf.logging. Training and validation are fully customizable through the appropriate flags (see below).
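To illustrate the split between the abstract train() loop and the per-epoch hooks, here is a minimal sketch in plain Python. It is not the code from this PR: it swaps tf.logging for the standard logging module, treats the model and session as opaque objects, and assumes that train_epoch returns a loss and val_epoch returns an accuracy. The num_epochs and val_every parameters are hypothetical stand-ins for the trainer flags, and DummyTrainer is a made-up subclass that only exercises the control flow.

```python
import abc
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("trainer")


class BaseTrainer(abc.ABC):
    """Sketch of an abstract trainer: subclasses supply the per-epoch work."""

    def __init__(self, model, session, num_epochs=3, val_every=1):
        # The PR passes the model and a TensorFlow session; here they are opaque objects.
        self.model = model
        self.session = session
        # Hypothetical stand-ins for the trainer flags.
        self.num_epochs = num_epochs
        self.val_every = val_every

    def train(self):
        """Run all epochs, logging the train loss and, every val_every epochs, the val accuracy."""
        history = []
        for epoch in range(1, self.num_epochs + 1):
            train_loss = self.train_epoch()
            log.info("epoch %d: train loss %.4f", epoch, train_loss)
            val_acc = None
            if epoch % self.val_every == 0:
                val_acc = self.val_epoch()
                log.info("epoch %d: val accuracy %.4f", epoch, val_acc)
            history.append((epoch, train_loss, val_acc))
        return history

    @abc.abstractmethod
    def train_epoch(self):
        """Train for one epoch and return the mean training loss."""

    @abc.abstractmethod
    def val_epoch(self):
        """Evaluate on the validation set and return the accuracy."""


class DummyTrainer(BaseTrainer):
    """Hypothetical toy subclass that returns fixed numbers instead of training a model."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss = 1.0

    def train_epoch(self):
        self.loss /= 2  # pretend each epoch halves the loss
        return self.loss

    def val_epoch(self):
        return 1.0 - self.loss


history = DummyTrainer(model=None, session=None, num_epochs=4, val_every=2).train()
```

Because train_epoch and val_epoch are declared abstract, BaseTrainer itself cannot be instantiated. A subclass only has to fill in those two hooks, while the epoch loop and logging stay in one place.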

The ResNetTrainer class is a concrete implementation. It provides the train_epoch and val_epoch methods that the BaseTrainer class calls and uses the AdamOptimizer to tune the model weights.
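For reference, the update that an Adam optimizer applies to each weight can be sketched in pure Python. This sketch follows the form given in the TensorFlow 1.x tf.train.AdamOptimizer documentation, where the bias correction is folded into the step size lr_t. The defaults mirror that optimizer's defaults. The function name adam_minimize and the toy quadratic objective are illustrative assumptions, not part of the PR, which applies the optimizer to the ResNet's loss.

```python
import math


def adam_minimize(grad_fn, w0, learning_rate=0.001, beta1=0.9, beta2=0.999,
                  epsilon=1e-8, steps=1000):
    """Minimize a scalar function with Adam, given a function returning its gradient."""
    w, m, v = w0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(w)
        # Exponential moving averages of the gradient and of the squared gradient.
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        # Bias correction folded into the step size.
        lr_t = learning_rate * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
        w -= lr_t * m / (math.sqrt(v) + epsilon)
    return w


# Toy objective f(w) = (w - 3)^2 with gradient 2 * (w - 3); its minimum is at w = 3.
w_final = adam_minimize(lambda w: 2 * (w - 3.0), w0=0.0, learning_rate=0.1, steps=1000)
```

Early on, each step moves the weight by roughly learning_rate regardless of the gradient's magnitude. This per-parameter normalization is one reason Adam is a common default for training networks such as ResNets.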

The following FLAGS are used by the trainer (declared in base_trainer.py):

The train/val batch sizes can be configured via the batch_size parameter of the TinyImageNetPipeline's initializer.
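The following sketch shows how a single batch_size passed to a pipeline's initializer determines the number of batches per epoch for both splits. ToyPipeline, its constructor arguments, and its batches/num_batches methods are hypothetical and only model this one aspect. The real TinyImageNetPipeline works on TensorFlow input tensors and its interface may differ.

```python
import math


class ToyPipeline:
    """Hypothetical stand-in for TinyImageNetPipeline; only batch_size handling is modeled."""

    def __init__(self, train_samples, val_samples, batch_size=64):
        self.splits = {"train": list(train_samples), "val": list(val_samples)}
        # One batch size shared by the train and val splits.
        self.batch_size = batch_size

    def batches(self, split):
        """Yield consecutive batches; the last one may be smaller than batch_size."""
        samples = self.splits[split]
        for start in range(0, len(samples), self.batch_size):
            yield samples[start:start + self.batch_size]

    def num_batches(self, split):
        """Number of batches needed to cover the split once."""
        return math.ceil(len(self.splits[split]) / self.batch_size)


pipeline = ToyPipeline(range(1000), range(100), batch_size=128)
train_batches = list(pipeline.batches("train"))
```

With 1000 training samples and batch_size=128, an epoch consists of seven full batches plus a final batch of 104 samples.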

When required in the future, the following features may be added to the trainer:

Merging this PR resolves #49.