vrenkens / nabu

Code for end-to-end ASR with neural networks, built with TensorFlow
MIT License
108 stars, 43 forks

Adding recurrent batch normalization #52

Closed: AzizCode92 closed this 5 years ago.

AzizCode92 commented 5 years ago

This is based on https://arxiv.org/pdf/1603.09025.pdf. I still have a problem handling update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) in trainer.py, and I would be very grateful if anyone could help me fix it.
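For reference, here is a summary of the BN-LSTM from the paper linked above, written from our reading of the paper rather than from nabu's code. The input-to-hidden and hidden-to-hidden terms are normalized separately, and the cell state is normalized before the output nonlinearity:

$$
\begin{aligned}
\begin{pmatrix} \tilde{f}_t \\ \tilde{i}_t \\ \tilde{o}_t \\ \tilde{g}_t \end{pmatrix}
  &= \mathrm{BN}(W_h h_{t-1};\, \gamma_h, \beta_h) + \mathrm{BN}(W_x x_t;\, \gamma_x, \beta_x) + b \\
c_t &= \sigma(\tilde{f}_t) \odot c_{t-1} + \sigma(\tilde{i}_t) \odot \tanh(\tilde{g}_t) \\
h_t &= \sigma(\tilde{o}_t) \odot \tanh\big(\mathrm{BN}(c_t;\, \gamma_c, \beta_c)\big) \\
\mathrm{BN}(h;\, \gamma, \beta) &= \beta + \gamma \odot \frac{h - \widehat{\mathbb{E}}[h]}{\sqrt{\widehat{\mathrm{Var}}[h] + \epsilon}}
\end{aligned}
$$

The paper keeps separate statistics for each time step, sets $\beta_h = \beta_x = 0$ (the bias $b$ already covers them) and recommends initializing $\gamma$ to 0.1, which is what the `scale` initializer in the batch_norm function posted in this thread does.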

Following the TensorFlow documentation:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:

x_norm = tf.layers.batch_normalization(x, training=training)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
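The note boils down to an exponential moving average that only advances when the update ops actually run. Below is a minimal plain-Python sketch of that behaviour for one feature (no TensorFlow; the function names are hypothetical), assuming the default momentum of 0.99 and a zero-initialized moving mean:

```python
# Plain-Python sketch (no TensorFlow) of what the moving-average update ops do
# for a single feature. Function names are illustrative, not TensorFlow's API.


def update_moving_stat(moving, batch_value, momentum=0.99):
    """moving <- momentum * moving + (1 - momentum) * batch_value"""
    return momentum * moving + (1.0 - momentum) * batch_value


def update_moving_stat_sub(moving, batch_value, momentum=0.99):
    """Equivalent form: moving <- moving - (moving - batch_value) * (1 - momentum)"""
    return moving - (moving - batch_value) * (1.0 - momentum)


def run_training(batch_means, run_update_ops, momentum=0.99):
    """Simulates training steps.

    run_update_ops=False models a train_op built without the control
    dependency on the update ops, so the moving mean is never touched.
    """
    moving_mean = 0.0  # the moving mean starts at zero
    for batch_mean in batch_means:
        # (The forward pass would normalize with batch_mean here.)
        if run_update_ops:
            moving_mean = update_moving_stat(moving_mean, batch_mean, momentum)
    return moving_mean


batches = [5.0] * 1000  # hypothetical: every batch has mean 5.0 for this feature
print(run_training(batches, run_update_ops=True))   # approaches 5.0
print(run_training(batches, run_update_ops=False))  # 0.0: never updated
```

With momentum 0.99 the initial value still carries a weight of 0.99^100 ≈ 0.37 after 100 updates, so early in training the inference statistics lag well behind the batch statistics; if the update ops never run at all, inference keeps normalizing with the initial mean 0 and variance 1.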

I tried this in trainer.py, inside the update function, but it still throws an error:


update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    apply_gradients_op = optimizer.apply_gradients(
        grads_and_vars=grads_and_vars,
        name='apply_gradients')
    # and do all other update ops
    update_op = tf.group(
        *([apply_gradients_op] + update_ops),
        name='update')

return update_op

The error message is:

  File "/home/ubuntu/workspace/reproduce/jobs/nabu/nabu/neuralnetworks/trainers/trainer.py", line 769, in train
    outputs['training_summaries']])
  File "/home/ubuntu/anaconda3/envs/tensorflow_p27/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 671, in run
    run_metadata=run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p27/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1156, in run
    run_metadata=run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p27/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1255, in run
    raise six.reraise(*original_exc_info)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p27/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1240, in run
    return self._sess.run(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p27/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1312, in run
    run_metadata=run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p27/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1076, in run
    return self._sess.run(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p27/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p27/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1152, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p27/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1328, in _do_run
    run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p27/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1348, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: node train/update (defined at /home/ubuntu/workspace/reproduce/jobs/nabu/nabu/neuralnetworks/trainers/trainer.py:578)  has inputs from different frames. The input node train/Listener/features/layer1/BLSTM/bidirectional_rnn/fw/fw/while/fw/bnlstm_cell/bnlstm_cell/state/batch_normalization/AssignMovingAvg (defined at /home/ubuntu/workspace/reproduce/jobs/nabu/nabu/neuralnetworks/components/recurrent_batch.py:61)  is in frame 'train/Listener/features/layer1/BLSTM/bidirectional_rnn/fw/fw/while/while_context'. The input node train/apply_gradients/Assign (defined at /home/ubuntu/workspace/reproduce/jobs/nabu/nabu/neuralnetworks/trainers/trainer.py:569)  is in frame ''.
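The error says that the node train/update takes an AssignMovingAvg op created inside the BLSTM's while loop (frame 'train/Listener/.../while/while_context') together with an op created outside any loop (frame ''). One possible restructuring, offered only as an untested assumption and not as nabu's or TensorFlow's API, is to let the loop body *return* each step's batch statistics instead of assigning them, and to update the population statistics once after the loop. In graph terms that would roughly mean emitting the batch mean/variance as extra RNN outputs and creating the tf.assign ops on those outputs after the RNN call, so they live in the same outer frame as apply_gradients. The plain-Python sketch below (single feature, hypothetical helper names) shows only the data flow, with separate statistics per time step:

```python
from statistics import fmean, pvariance


def normalize(values, mean, var, eps=1e-3):
    """Standardize one feature's values over the batch."""
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


def recurrent_forward(inputs_per_step, population, is_training, eps=1e-3):
    """Unrolled stand-in for the RNN while loop (hypothetical helper).

    inputs_per_step[t]: the feature's values over the batch at time step t.
    population[t]: (mean, var) kept separately for each time step.
    In training mode each step's batch statistics are returned as outputs
    instead of being assigned to the population inside the loop.
    """
    outputs, step_stats = [], []
    for t, batch in enumerate(inputs_per_step):
        if is_training:
            mean, var = fmean(batch), pvariance(batch)  # biased variance
            step_stats.append((mean, var))
        else:
            # Steps beyond the stored ones reuse the last stored statistics.
            mean, var = population[min(t, len(population) - 1)]
        outputs.append(normalize(batch, mean, var, eps))
    return outputs, step_stats


def update_population(population, step_stats, decay=0.99):
    """Runs after the loop: one moving-average update per observed time step."""
    updated = list(population)
    for t, (batch_mean, batch_var) in enumerate(step_stats[:len(population)]):
        pop_mean, pop_var = population[t]
        updated[t] = (decay * pop_mean + (1 - decay) * batch_mean,
                      decay * pop_var + (1 - decay) * batch_var)
    return updated


population = [(0.0, 1.0)] * 3                  # hypothetical: 3 time steps
inputs = [[1.0, 3.0], [2.0, 6.0], [0.0, 4.0]]  # batch of 2 at each step
outputs, stats = recurrent_forward(inputs, population, is_training=True)
population = update_population(population, stats)
```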

AzizCode92 commented 5 years ago

Or should I just keep the code in trainer.py as you wrote it?

AzizCode92 commented 5 years ago

Related issues: issue1. In the meantime, I modified the batch_norm function so that I can handle the moving mean/variance manually during inference:

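import tensorflow as tf


# Thanks to https://github.com/OlavHN/bnlstm
def batch_norm(inputs, name_scope, is_training, epsilon=1e-3, decay=0.99):
    """Batch-normalizes a 2-D [batch, size] tensor.

    is_training must be a Python bool: the branch below is chosen when the
    graph is built (a boolean tf.Tensor would need tf.cond instead).
    """
    with tf.variable_scope(name_scope):
        size = inputs.get_shape().as_list()[1]

        # Learned scale (gamma) and offset (beta).
        scale = tf.get_variable(
            'scale', [size], initializer=tf.constant_initializer(0.1))
        # Start the offset at zero; without an initializer tf.get_variable
        # falls back to random Glorot-uniform values.
        offset = tf.get_variable(
            'offset', [size], initializer=tf.zeros_initializer())

        # Population statistics used at inference time (not trained).
        population_mean = tf.get_variable(
            'population_mean', [size],
            initializer=tf.zeros_initializer(), trainable=False)
        population_var = tf.get_variable(
            'population_var', [size],
            initializer=tf.ones_initializer(), trainable=False)
        # Per-feature statistics of the current batch.
        batch_mean, batch_var = tf.nn.moments(inputs, [0])

        # The following part is based on the implementation of:
        # https://github.com/cooijmanstim/recurrent-batch-normalization
        # Exponential moving averages of the batch statistics.
        train_mean_op = tf.assign(
            population_mean,
            population_mean * decay + batch_mean * (1 - decay))
        train_var_op = tf.assign(
            population_var, population_var * decay + batch_var * (1 - decay))

        if is_training:
            # Normalize with batch statistics; the control dependency makes
            # the moving averages update whenever this output is computed.
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(
                    inputs, batch_mean, batch_var, offset, scale, epsilon)
        else:
            # Normalize with the accumulated population statistics.
            return tf.nn.batch_normalization(
                inputs, population_mean, population_var, offset, scale,
                epsilon)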
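To check the normalization math separately from the TensorFlow graph, the function above can be mirrored in plain Python. This is a hypothetical reference helper (not part of nabu), assuming a [batch][size] list of rows and the same biased variance that tf.nn.moments computes:

```python
from statistics import fmean, pvariance


def batch_norm_reference(inputs, scale, offset, pop_mean, pop_var,
                         is_training, epsilon=1e-3, decay=0.99):
    """Plain-Python mirror of batch_norm for a [batch][size] list of rows.

    Returns (outputs, new_pop_mean, new_pop_var); the population statistics
    only change in training mode, like the tf.assign ops in batch_norm.
    """
    columns = list(zip(*inputs))  # one tuple of batch values per feature
    if is_training:
        mean = [fmean(col) for col in columns]
        var = [pvariance(col) for col in columns]  # biased, like tf.nn.moments
        new_mean = [p * decay + b * (1 - decay) for p, b in zip(pop_mean, mean)]
        new_var = [p * decay + b * (1 - decay) for p, b in zip(pop_var, var)]
    else:
        mean, var = pop_mean, pop_var
        new_mean, new_var = list(pop_mean), list(pop_var)
    # Same formula as tf.nn.batch_normalization:
    # offset + scale * (x - mean) / sqrt(var + epsilon)
    outputs = [
        [offset[j] + scale[j] * (row[j] - mean[j]) / (var[j] + epsilon) ** 0.5
         for j in range(len(row))]
        for row in inputs
    ]
    return outputs, new_mean, new_var


x = [[1.0, 10.0], [3.0, 30.0]]  # hypothetical batch: 2 examples, 2 features
out, m, v = batch_norm_reference(
    x, scale=[0.1, 0.1], offset=[0.0, 0.0],
    pop_mean=[0.0, 0.0], pop_var=[1.0, 1.0], is_training=True)
```

Feeding the same inputs to this reference and to the TensorFlow version is one way to tell whether a problem such as the early convergence comes from the normalization itself or from how the statistics are updated inside the RNN loop.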

But it did not work properly: training converged too early.

AzizCode92 commented 5 years ago

@vrenkens: do you have any idea how I can fix this issue?