tensorflow / fold

Deep learning with dynamic computation graphs in TensorFlow
Apache License 2.0
1.82k stars 266 forks source link

Error when trying to implement a N-aryTreeLSTM #71

Closed bdqnghi closed 6 years ago

bdqnghi commented 6 years ago

I'm trying to implement a more generalized version of the BinaryTreeLSTM example, which can be use for any N-aryTreeLSTM

This is the implementation:

class NaryTreeLSTMCell(tf.contrib.rnn.BasicLSTMCell):

  def __init__(self, num_units, forget_bias=1.0, activation=tf.tanh,
               keep_prob=1.0, seed=None):

    super(NaryTreeLSTMCell, self).__init__(
        num_units, forget_bias=forget_bias, activation=activation)
    self._keep_prob = keep_prob
    self._seed = seed

  def __call__(self, inputs, state, scope=None):
    with tf.variable_scope(scope or type(self).__name__):
      c_list = list()
      h_list = list()
      for child in state:
        c,h = child
        c_list.append(c)
        h_list.append(h)

      concat = list()
      concat.append(inputs)
      concat.extend(h_list)

      concat = tf.contrib.layers.linear(tf.concat(concat, 1), 5 * self._num_units)

      len_state = len(state)

      params = tf.split(value=concat, num_or_size_splits=3+len(state), axis=1)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i = params[0]
      j = params[1]
      o = params[3+len(state)-1]
      f_list = params[2:(3+len(state)-1)]

      j = self._activation(j)
      if not isinstance(self._keep_prob, float) or self._keep_prob < 1:
        j = tf.nn.dropout(j, self._keep_prob, seed=self._seed)

      new_c = c_list[0] * tf.sigmoid(f_list[0] + self._forget_bias)
      for i,c in enumerate(c_list):
        if i > 0:
          new_c += c * tf.sigmoid(f_list[i] + self._forget_bias)

      new_c += tf.sigmoid(i) * j
      new_h = self._activation(new_c) * tf.sigmoid(o)

      new_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)

      return new_h, new_state

I get this error: TypeError: Value passed to parameter 'x' has DataType int32 not in list of allowed values: float16, float32, float64, complex64, complex128

on this line :

new_c += tf.sigmoid(i) * j

Any idea to fix this?

pklfz commented 6 years ago

the i in tf.sigmoid(i) has referred to the i in for i,c in enumerate(c_list):, which has DataType int32.

bdqnghi commented 6 years ago

Ok I fixed it, thanks