openai / coinrun

Code for the paper "Quantifying Generalization in Reinforcement Learning"
https://blog.openai.com/quantifying-generalization-in-reinforcement-learning/
MIT License

batch_norm #17

Closed brantbzhang closed 5 years ago

brantbzhang commented 5 years ago

In ppo2.py there is a call to tf.get_collection(tf.GraphKeys.UPDATE_OPS), but the collected ops are never used as control dependencies, like this:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
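The sketch below shows why that dependency matters. It is a minimal NumPy illustration, not the coinrun or TensorFlow code, and the class and function names are hypothetical. The moving-average update is a separate step that plays the role of an op in UPDATE_OPS. If the training loop never runs it, the moving statistics stay at their initial values (mean 0, variance 1).

```python
import numpy as np

class BatchNormSketch:
    """Hypothetical batch-norm layer whose forward pass and moving-average
    update are separate, mimicking how TF1 places the update in UPDATE_OPS."""

    def __init__(self, dim, momentum=0.99, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.moving_mean = np.zeros(dim)  # what is_training=False would use
        self.moving_var = np.ones(dim)

    def forward_train(self, x):
        # is_training=True: normalize with the current batch's statistics.
        mean, var = x.mean(axis=0), x.var(axis=0)
        return (x - mean) / np.sqrt(var + self.eps)

    def update_moving_stats(self, x):
        # The "update op": exponential moving average of batch statistics.
        m = self.momentum
        self.moving_mean = m * self.moving_mean + (1 - m) * x.mean(axis=0)
        self.moving_var = m * self.moving_var + (1 - m) * x.var(axis=0)

def train(bn, batches, run_update_ops):
    for x in batches:
        bn.forward_train(x)  # a real training step would use this output
        if run_update_ops:   # analogous to tf.control_dependencies(update_ops)
            bn.update_moving_stats(x)
    return bn

rng = np.random.default_rng(0)
batches = [rng.normal(loc=5.0, scale=2.0, size=(64, 3)) for _ in range(500)]

skipped = train(BatchNormSketch(3), batches, run_update_ops=False)
wired = train(BatchNormSketch(3), batches, run_update_ops=True)
print(skipped.moving_mean)  # still all zeros: the update never ran
print(wired.moving_mean)    # close to the data mean of 5.0
```

In TF1 the same thing happens implicitly. The layer registers its update ops, and they only execute if something in the fetched graph depends on them.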

In the step and value methods of CnnPolicy, why isn't batch_norm called with is_training=False?

kcobbe commented 5 years ago

This code only supports running batch normalization with the statistics of the current batch (is_training = True), as this is what is effective at training time. You're correct that if you add the UPDATE_OPS dependencies to the graph, you'll be able to run batch normalization with a moving average of those statistics (is_training = False), which will usually slightly improve test-time performance.
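The following is a minimal NumPy sketch of the two modes described above. The function names are hypothetical, and the moving statistics are simply assumed to have been accumulated already. With batch statistics, the normalized value of one observation changes when the rest of its batch changes. With moving averages, each observation is normalized independently of its batchmates.

```python
import numpy as np

EPS = 1e-5

def bn_batch_stats(x):
    # is_training=True: normalize with the mean/variance of this batch.
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + EPS)

def bn_moving_stats(x, moving_mean, moving_var):
    # is_training=False: normalize with previously accumulated statistics.
    return (x - moving_mean) / np.sqrt(moving_var + EPS)

# Assumed moving statistics, as if accumulated during training.
moving_mean, moving_var = np.array([5.0]), np.array([4.0])

obs = np.array([[7.0]])                      # the observation we care about
batch_a = np.vstack([obs, [[3.0], [5.0]]])   # same obs, different batchmates
batch_b = np.vstack([obs, [[9.0], [11.0]]])

# Batch statistics: the same observation gets different outputs.
train_a, train_b = bn_batch_stats(batch_a)[0], bn_batch_stats(batch_b)[0]
# Moving statistics: the output for obs is identical in both batches.
eval_a = bn_moving_stats(batch_a, moving_mean, moving_var)[0]
eval_b = bn_moving_stats(batch_b, moving_mean, moving_var)[0]
print(train_a, train_b)
print(eval_a, eval_b)
```

One consequence is worth noting. If step() and value() are evaluated on a batch of observations from parallel environments, batch statistics make each environment's output depend on what the other environments are currently showing.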