jianlong-yuan / syncbn-tensorflow

Synchronized Multi-GPU Batch Normalization

error when using recurrent bidirectional batch normalization (multi-gpu) #3

Closed AzizCode92 closed 3 years ago

AzizCode92 commented 5 years ago

I used your code to implement recurrent batch normalization. My encoder is a pyramidal bidirectional LSTM.

Here is the error message.

InvalidArgumentError (see above for traceback): Cannot colocate nodes 'train/Listener/features/layer0/BLSTM/bidirectional_rnn/bw/bw/while/batch_norm_lstm_cell/cond/NcclAllReduce_1' and 'train/gradients/train/Listener/features/layer0/BLSTM/bidirectional_rnn/bw/bw/while/batch_norm_lstm_cell/cond/mul_1_grad/Mul_1/Const' because no device type supports both of those nodes and the other nodes colocated with them.
Colocation Debug Info:
Colocation group had the following types and devices: 
NcclAllReduce: GPU 
Switch: CPU 
Enter: GPU CPU 
StackV2: 
Const: GPU CPU 

Colocation members and user-requested devices:
  train/gradients/train/Listener/features/layer0/BLSTM/bidirectional_rnn/bw/bw/while/batch_norm_lstm_cell/cond/mul_1_grad/Mul_1/Const (Const) /job:worker/task:7
  train/gradients/train/Listener/features/layer0/BLSTM/bidirectional_rnn/bw/bw/while/batch_norm_lstm_cell/cond/mul_1_grad/Mul_1/f_acc (StackV2) /job:worker/task:7
  train/gradients/train/Listener/features/layer0/BLSTM/bidirectional_rnn/bw/bw/while/batch_norm_lstm_cell/cond/mul_1_grad/Mul_1/Switch/Enter (Enter) /job:worker/task:7
  train/gradients/train/Listener/features/layer0/BLSTM/bidirectional_rnn/bw/bw/while/batch_norm_lstm_cell/cond/mul_1_grad/Mul_1/Switch (Switch) /job:worker/task:7
  train/Listener/features/layer0/BLSTM/bidirectional_rnn/bw/bw/while/batch_norm_lstm_cell/cond/NcclAllReduce_1 (NcclAllReduce) /job:worker/task:7

     [[Node: train/Listener/features/layer0/BLSTM/bidirectional_rnn/bw/bw/while/batch_norm_lstm_cell/cond/NcclAllReduce_1 = NcclAllReduce[T=DT_FLOAT, num_devices=7, reduction="sum", shared_name="Listener/f...ean_square", _device="/job:worker/task:7"](train/Listener/features/layer0/BLSTM/bidirectional_rnn/bw/bw/while/batch_norm_lstm_cell/cond/Mean_1)]]
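The "Colocation Debug Info" above can be read as a set-intersection problem: every node in a colocation group has to run on one device type that all members support. The sketch below only re-states the device lists from the log as Python sets; it is an illustration of what the message reports, not TensorFlow's placement code. StackV2 is left out because the log lists no device types for it.

```python
# Device types per op, copied from the "Colocation Debug Info" in the log.
# StackV2 is omitted: the log shows no device types for it.
supported = {
    "NcclAllReduce": {"GPU"},
    "Switch": {"CPU"},
    "Enter": {"GPU", "CPU"},
    "Const": {"GPU", "CPU"},
}


def common_devices(ops):
    """Device types that every op in `ops` supports."""
    return set.intersection(*(supported[name] for name in ops))


# The whole group: NcclAllReduce is GPU-only, Switch is CPU-only.
common = common_devices(supported)

# Dropping either constrained op leaves a usable device type.
without_nccl = common_devices(n for n in supported if n != "NcclAllReduce")
without_switch = common_devices(n for n in supported if n != "Switch")

print("whole group:", common)
print("without NcclAllReduce:", without_nccl)
print("without Switch:", without_switch)
```

The intersection for the whole group is empty, which matches "no device type supports both of those nodes": the GPU-only NcclAllReduce has been grouped with a CPU-only Switch that belongs to the gradient of the `cond` inside the while loop.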

Any idea how to fix it?

jianlong-yuan commented 5 years ago

more details?

AzizCode92 commented 5 years ago

I think this behaviour is caused by the tf.while_loop inside the bidirectional LSTM. I faced a similar issue when using tf.control_dependencies with a bidirectional LSTM. In other words, the tf.while_loop and the NCCL operations are used together, and this makes TensorFlow complain. For more details, my encoder is as follows:

with tf.variable_scope(scope or 'BLSTM'):

        # create the LSTM cells used for the forward and backward passes
        lstm_cell_fw = BNLSTMCell(
            is_training=is_training,
            num_units=num_units,
            reuse=tf.get_variable_scope().reuse)
        lstm_cell_bw = BNLSTMCell(
            is_training=is_training,
            num_units=num_units,
            reuse=tf.get_variable_scope().reuse)

        #do the forward computation
        outputs_tupple, _ = bidirectional_dynamic_rnn(
            lstm_cell_fw, lstm_cell_bw, inputs, dtype=tf.float32,
            sequence_length=sequence_length)

        outputs = tf.concat(outputs_tupple, 2)

        return outputs

and my LSTM cell is as follows:

class BNLSTMCell(RNNCell):
    """Batch Normalized Basic LSTM recurrent network cell.
    The implementation is based on: http://arxiv.org/abs/1409.2329.
    We add forget_bias (default: 1) to the biases of the forget gate in order to
    reduce the scale of forgetting in the beginning of the training.
    It does not allow cell clipping, a projection layer, and does not
    use peep-hole connections: it is the basic baseline.
    For advanced models, please use the full LSTMCell that follows.
    """

    def __init__(self, is_training, num_units, forget_bias=1.0, input_size=None,
                 state_is_tuple=True, reuse=None):
        """Initialize the basic LSTM cell.
        Args:
          num_units: int, The number of units in the LSTM cell.
          is_training: bool, set True when training.
          forget_bias: float, The bias added to forget gates (see above).
          input_size: Deprecated and unused.
          state_is_tuple: If True, accepted and returned states are 2-tuples of
            the `c_state` and `m_state`.  If False, they are concatenated
            along the column axis.  The latter behavior will soon be deprecated.
          reuse: (optional) Python boolean describing whether to reuse variables
            in an existing scope.  If not `True`, and the existing scope already has
            the given variables, an error is raised.
        """
        if not state_is_tuple:
            logging.warn("%s: Using a concatenated state is slower and will soon be "
                         "deprecated.  Use state_is_tuple=True.", self)
        if input_size is not None:
            logging.warn("%s: The input_size parameter is deprecated.", self)
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._reuse = reuse
        self._is_training=is_training

    @property
    def state_size(self):
        return (LSTMStateTuple(self._num_units, self._num_units)
                if self._state_is_tuple else 2 * self._num_units)

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None,is_training=True):

        """Long short-term memory cell (LSTM) with Recurrent Batch Normalization."""
        #input_size = inputs.get_shape().as_list()[1]
        input_size = inputs.get_shape().with_rank(2)[1]
        if input_size.value is None:
            raise ValueError(
                "Could not infer input size from inputs.get_shape()[-1]")

        is_training = self._is_training

        with tf.variable_scope(scope or "batch_norm_lstm_cell", reuse=self._reuse):
            # Parameters of gates are concatenated into one multiply for
            # efficiency.
            if self._state_is_tuple:
                c_prev, h_prev = state
            else:
                c_prev, h_prev = tf.split(
                    value=state, num_or_size_splits=2, axis=1)

            W_xh = tf.get_variable('W_xh', shape=[input_size.value, 4 * self._num_units],
                                   initializer=orthogonal_initializer())
            W_hh = tf.get_variable('W_hh', shape=[self._num_units, 4 * self._num_units],
                                   initializer=orthogonal_initializer())
            bias = tf.get_variable('b', [4 * self._num_units],initializer=tf.constant_initializer(0.0))

            lstm_matrix_i = batch_norm(tf.matmul(inputs,W_xh),is_training=is_training,num_dev=7)

            lstm_matrix_r = tf.matmul(h_prev, W_hh)
            #lstm_matrix_r = BatchNorm(tf.matmul(h_prev,W_hh),is_training=is_training,num_dev=7,name_scope="lstm_matrix_r")

            lstm_matrix = tf.nn.bias_add(math_ops.add(lstm_matrix_i, lstm_matrix_r), bias)

            i, g, f, o = tf.split(value=lstm_matrix, num_or_size_splits=4, axis=1)

            c = (c_prev * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * tf.tanh(g))

            h = tf.tanh(c) * tf.sigmoid(o)
            #h = tf.sigmoid(o) * tf.tanh(BatchNorm(c, is_training=is_training,name_scope="C"))

            if self._state_is_tuple:
                new_state = LSTMStateTuple(c, h)
            else:
                new_state = tf.concat(values=[c, h], axis=1)
            return h, new_state

where batch_norm is a function taken from your code.

jianlong-yuan commented 5 years ago

In each loop, you could change 'shared_name', e.g. 'NCCL1', 'NCCL2', so that each NCCL op is different.
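One possible reading of this suggestion is to generate a distinct shared_name each time the batch-norm code is built. The helper below is a hypothetical sketch in plain Python: `unique_shared_name` is not part of this repository, and how the name would be passed to batch_norm (for example through an extra argument) is an assumption. Note that in graph mode the body of tf.while_loop is constructed once, so a helper like this can only distinguish call sites (such as the forward and backward cells), not individual time steps.

```python
import itertools
from collections import defaultdict

# One independent counter per prefix (hypothetical helper, not from the repo).
_counters = defaultdict(itertools.count)


def unique_shared_name(prefix):
    """Return prefix_0, prefix_1, ... with a new suffix on every call."""
    return "%s_%d" % (prefix, next(_counters[prefix]))


# Each call site that builds a batch-norm op would ask for its own name.
names = [unique_shared_name("NCCL") for _ in range(3)]
other = unique_shared_name("mean_square")

print(names)
print(other)
```

Whether distinct names resolve the colocation error reported above is not confirmed in this thread.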