tensorflow / nmt

TensorFlow Neural Machine Translation Tutorial

Inference with GreedyEmbeddingHelper raises an error #363

Open vhuytdt opened 6 years ago

vhuytdt commented 6 years ago

I am trying to run inference with GreedyEmbeddingHelper, but it does not work:


```python
decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units, forget_bias=True)

tgt_sos_id = tf.cast(1, tf.int32)
tgt_eos_id = tf.cast(2, tf.int32)

# Helper
start_tokens = tf.fill([batch_size], tgt_sos_id)
end_token = tgt_eos_id

helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_emb_inp, start_tokens, end_token)

projection_layer = Dense(7712, use_bias=False)

maximum_iterations = tf.round(tf.reduce_max(dec_train_inp_lengths * 2))
print(maximum_iterations)

# Decoder
decoder = tf.contrib.seq2seq.BasicDecoder(
    decoder_cell, helper, encoder_state)
print(decoder)

# Dynamic decoding
outputs, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=maximum_iterations, swap_memory=True)

train_prediction = outputs.sample_id
```

Error:
`ValueError: Input 0 of layer basic_lstm_cell_12 is incompatible with the layer: expected ndim=2, found ndim=3. Full shape received: [64, 64, 300]`
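
A possible reading of this error, offered as a sketch rather than a confirmed diagnosis: `GreedyEmbeddingHelper` takes the embedding matrix (the `params` argument for `tf.nn.embedding_lookup`, shape `[vocab_size, emb_dim]`) or a callable as its first argument, while the code above passes `decoder_emb_inp`, which looks like already-embedded inputs of rank 3. Looking up the start tokens in a rank-3 tensor yields a rank-3 result, which would match the reported `[64, 64, 300]`. The pure-Python snippet below mimics that lookup on nested lists; the sizes and the names `embedding_matrix`, `lookup`, and `shape` are illustrative assumptions, not part of the original code.

```python
# Pure-Python sketch of how an embedding lookup gathers rows by token id.
# All sizes are small stand-ins, not the values from the report.

def shape(x):
    """Return the shape of a rectangular nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims

def lookup(params, ids):
    """Mimic an embedding lookup for 1-D ids: gather along axis 0."""
    return [params[i] for i in ids]

batch_size, max_time, emb_dim, vocab_size = 4, 5, 3, 10
sos_id = 1

# Expected argument: an embedding matrix of shape [vocab_size, emb_dim].
embedding_matrix = [[0.0] * emb_dim for _ in range(vocab_size)]

# What the posted code appears to pass: embedded inputs,
# assumed here to have shape [batch_size, max_time, emb_dim].
decoder_emb_inp = [[[0.0] * emb_dim for _ in range(max_time)]
                   for _ in range(batch_size)]

start_tokens = [sos_id] * batch_size

good = shape(lookup(embedding_matrix, start_tokens))  # rank 2: [batch, emb_dim]
bad = shape(lookup(decoder_emb_inp, start_tokens))    # rank 3: [batch, time, emb_dim]
print(good, bad)
```

If this reading is right, passing the decoder embedding variable itself to `GreedyEmbeddingHelper` should give the LSTM cell the rank-2 input it expects. Separately, `projection_layer` is created but never handed to `BasicDecoder`; `BasicDecoder` accepts an `output_layer` argument, and without it `sample_id` would be taken over the cell's `num_units` outputs rather than the 7712 vocabulary logits.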