I'm trying to implement a more generalized version of the BinaryTreeLSTM example, which can be use for any N-aryTreeLSTM
This is the implementation:
class NaryTreeLSTMCell(tf.contrib.rnn.BasicLSTMCell):
def __init__(self, num_units, forget_bias=1.0, activation=tf.tanh,
keep_prob=1.0, seed=None):
super(NaryTreeLSTMCell, self).__init__(
num_units, forget_bias=forget_bias, activation=activation)
self._keep_prob = keep_prob
self._seed = seed
def __call__(self, inputs, state, scope=None):
with tf.variable_scope(scope or type(self).__name__):
c_list = list()
h_list = list()
for child in state:
c,h = child
c_list.append(c)
h_list.append(h)
concat = list()
concat.append(inputs)
concat.extend(h_list)
concat = tf.contrib.layers.linear(tf.concat(concat, 1), 5 * self._num_units)
len_state = len(state)
params = tf.split(value=concat, num_or_size_splits=3+len(state), axis=1)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i = params[0]
j = params[1]
o = params[3+len(state)-1]
f_list = params[2:(3+len(state)-1)]
j = self._activation(j)
if not isinstance(self._keep_prob, float) or self._keep_prob < 1:
j = tf.nn.dropout(j, self._keep_prob, seed=self._seed)
new_c = c_list[0] * tf.sigmoid(f_list[0] + self._forget_bias)
for i,c in enumerate(c_list):
if i > 0:
new_c += c * tf.sigmoid(f_list[i] + self._forget_bias)
new_c += tf.sigmoid(i) * j
new_h = self._activation(new_c) * tf.sigmoid(o)
new_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
return new_h, new_state
I get this error:
TypeError: Value passed to parameter 'x' has DataType int32 not in list of allowed values: float16, float32, float64, complex64, complex128
I'm trying to implement a more generalized version of the BinaryTreeLSTM example, which can be use for any N-aryTreeLSTM
This is the implementation:
I get this error:
TypeError: Value passed to parameter 'x' has DataType int32 not in list of allowed values: float16, float32, float64, complex64, complex128
on this line :
new_c += tf.sigmoid(i) * j
Any idea to fix this?