https://github.com/tensorflow/nmt/blob/b278487980832417ad8ac701c672b5c3dc7fa553/nmt/model.py#L782-L787 This for loop only seems to append the encoder states in alternating order (i.e. [layer_1 forward state, layer_1 backward state, layer_2 forward state, layer_2 backward state, ...]). However, to pass these states on to the decoder, don't we also need to concatenate the forward and backward states together (something the comment above the for loop also suggests), along the following lines:
import tensorflow as tf
from tensorflow.contrib.rnn import LSTMStateTuple

def _concat(state_1, state_2):
    assert type(state_1) == type(state_2)
    if isinstance(state_1, LSTMStateTuple):
        # Concatenate the cell (c) and hidden (h) states along the last axis.
        return LSTMStateTuple(tf.concat([state_1.c, state_2.c], -1),
                              tf.concat([state_1.h, state_2.h], -1))
    else:
        return tf.concat([state_1, state_2], -1)

encoder_state = []
for layer_id in range(num_bi_layers):
    # bi_encoder_state = (forward states, backward states); merge them per layer.
    encoder_state.append(_concat(bi_encoder_state[0][layer_id],
                                 bi_encoder_state[1][layer_id]))
encoder_state = tuple(encoder_state)
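For illustration, here is a minimal sanity check of the `_concat` helper above, using hypothetical placeholder shapes (`batch_size` and `num_units` are made-up values, not taken from the repo). The point is that the merged state is twice as wide, so a decoder consuming it would need cells of size 2 * num_units:

    import tensorflow as tf
    from tensorflow.contrib.rnn import LSTMStateTuple

    # Hypothetical sizes, for illustration only.
    batch_size, num_units = 4, 8

    fw_state = LSTMStateTuple(tf.zeros([batch_size, num_units]),
                              tf.zeros([batch_size, num_units]))
    bw_state = LSTMStateTuple(tf.ones([batch_size, num_units]),
                              tf.ones([batch_size, num_units]))

    merged = _concat(fw_state, bw_state)  # reuses _concat defined above
    print(merged.c.shape)  # (4, 16): forward and backward halves side by side
    print(merged.h.shape)  # (4, 16)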