Open 0b01 opened 7 years ago
I have figured it out:
enc_cells_fw = []
for i in range(encoder_depth):
    with tf.variable_scope('enc_RNN_{}'.format(i)):
        cell = tf.contrib.rnn.LSTMCell(hidden_dim)
        # wrap with dropout; `dropout` here is the drop probability
        cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1.0 - dropout)
        enc_cells_fw.append(cell)
enc_cell_fw = tf.contrib.rnn.MultiRNNCell(enc_cells_fw, state_is_tuple=True)

enc_cells_bw = []
for i in range(encoder_depth):
    with tf.variable_scope('enc_RNN_{}'.format(i)):
        cell = tf.contrib.rnn.LSTMCell(hidden_dim)
        cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1.0 - dropout)
        enc_cells_bw.append(cell)
enc_cell_bw = tf.contrib.rnn.MultiRNNCell(enc_cells_bw, state_is_tuple=True)
init_state = enc_cell_fw.zero_state(batch_size=batch_size, dtype=tf.float32)
# transpose encoder inputs to time-major
enc_inp_t = tf.transpose(enc_inp, [1,0,2])
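# Note: tf.scan iterates over axis 0, so the inputs must be time-major:
# enc_inp is [batch, time, embed] -> enc_inp_t is [time, batch, embed].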
# the bidirectional encoder
with tf.variable_scope('encoder-fw'):  # forward pass over the sequence
    enc_output_fw, enc_states_fw = tf.scan(
        # Python 3 lambdas can't unpack tuples, so index the accumulator:
        # acc = (prev_output, prev_state)
        lambda acc, x: enc_cell_fw(x, acc[1]),
        enc_inp_t,
        initializer=(tf.zeros(shape=[batch_size, hidden_dim]), init_state))
with tf.variable_scope('encoder-bw'):  # backward pass over the sequence
    enc_output_bw, enc_states_bw = tf.scan(
        lambda acc, x: enc_cell_bw(x, acc[1]),
        tf.reverse(enc_inp_t, axis=[0]),  # <- reverse inputs
        initializer=(tf.zeros(shape=[batch_size, hidden_dim]), init_state))
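# tf.scan stacks the accumulator at every step, so enc_output_fw/bw come back
# time-major ([time, batch, hidden_dim]) and every tensor inside
# enc_states_fw/bw also carries a leading time axis; [-1] below picks the
# final step.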
# re-reverse the backward outputs so they line up in time with the forward
# pass (bidirectional_dynamic_rnn does this internally)
enc_output_bw = tf.reverse(enc_output_bw, axis=[0])
enc_output_fw = tf.transpose(enc_output_fw, [1, 0, 2])  # back to batch-major
enc_output_bw = tf.transpose(enc_output_bw, [1, 0, 2])
encoder_outputs = tf.concat([enc_output_fw, enc_output_bw], 2)
# project context
Wc = tf.get_variable('Wc', shape=[2, encoder_depth, hidden_dim*2, hidden_dim*2],
                     initializer=tf.contrib.layers.xavier_initializer())
# extract context [take the final state per layer; project the concatenated
# fw/bw c,h ([batch, hidden_dim*2] -> [batch, hidden_dim*2]); list -> tuple]
encoder_final_state = []
for layer in range(encoder_depth):
    enc_c = tf.concat((enc_states_fw[layer].c[-1], enc_states_bw[layer].c[-1]), 1)
    enc_c = tf.matmul(enc_c, Wc[0][layer])
    enc_h = tf.concat((enc_states_fw[layer].h[-1], enc_states_bw[layer].h[-1]), 1)
    enc_h = tf.matmul(enc_h, Wc[1][layer])
    encoder_final_state.append(tf.contrib.rnn.LSTMStateTuple(c=enc_c, h=enc_h))
# convert list to tuple - eww!
encoder_final_state = tuple(encoder_final_state)
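In case it helps: since Wc maps the concatenated state to hidden_dim*2, the tuple above matches the state signature of a decoder stack whose layers are twice as wide. A minimal sketch, where dec_inp stands in for a hypothetical batch-major [batch, time, embed] tensor of decoder inputs:

dec_cell = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.LSTMCell(hidden_dim * 2) for _ in range(encoder_depth)],
    state_is_tuple=True)
dec_outputs, dec_final_state = tf.nn.dynamic_rnn(
    dec_cell, dec_inp,                  # dec_inp is hypothetical
    initial_state=encoder_final_state,  # the tuple built above
    dtype=tf.float32)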
Hello @suriyadeepan, nice repo you have here. I am trying to resolve this bug: https://github.com/tensorflow/tensorflow/issues/10862. I find your functional tf.scan approach refreshing. My only question: how do I get the outputs from the RNN?
For a typical bidirectional_dynamic_rnn, the return value is a pair of outputs and states:
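Roughly (a sketch against the TF 1.x API; cell_fw, cell_bw, inputs, and seq_len stand in for your own cells and tensors):

outputs, states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs,
    sequence_length=seq_len, dtype=tf.float32)
# outputs = (output_fw, output_bw), each [batch, time, hidden] when time_major=False
# states  = (state_fw, state_bw), the final state of each direction
encoder_outputs = tf.concat(outputs, 2)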