tensorflow / nmt

TensorFlow Neural Machine Translation Tutorial
Apache License 2.0

Concatenating encoder_states #423

Open shehryar-malik opened 5 years ago

shehryar-malik commented 5 years ago

https://github.com/tensorflow/nmt/blob/b278487980832417ad8ac701c672b5c3dc7fa553/nmt/model.py#L782-L787 This for loop only appends the encoder states in order (e.g. [layer_1 forward state, layer_1 backward state, layer_2 forward state, layer_2 backward state, ...]). However, before passing these states on to the decoder, don't we also need to concatenate each layer's forward and backward states together (which is what the comment above the for loop indicates)? Something along the following lines:

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMStateTuple

def _concat(state_1, state_2):
    assert type(state_1) == type(state_2)
    if isinstance(state_1, LSTMStateTuple):
        # Concatenate the cell (c) and hidden (h) states separately.
        return LSTMStateTuple(
            tf.concat([state_1.c, state_2.c], -1),
            tf.concat([state_1.h, state_2.h], -1))
    else:
        return tf.concat([state_1, state_2], -1)

encoder_state = []
for layer_id in range(num_bi_layers):
    # bi_encoder_state[0] holds the forward states, bi_encoder_state[1] the backward ones.
    encoder_state.append(
        _concat(bi_encoder_state[0][layer_id], bi_encoder_state[1][layer_id]))
encoder_state = tuple(encoder_state)
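For illustration only, here is a minimal NumPy sketch (not the NMT code itself, and with made-up shapes) of what the proposed `_concat` does to a single non-LSTM layer state: concatenating the forward and backward states along the last axis doubles the state dimension, so a decoder consuming this state would need `2 * num_units` hidden units per layer.

```python
import numpy as np

batch_size, num_units = 2, 4

# Hypothetical final states from one bidirectional layer.
fw_state = np.ones((batch_size, num_units))   # forward direction
bw_state = np.zeros((batch_size, num_units))  # backward direction

# Concatenate along the last axis, as tf.concat([...], -1) would.
merged = np.concatenate([fw_state, bw_state], axis=-1)
print(merged.shape)  # (2, 8): state size doubles from num_units to 2 * num_units
```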