tensorflow / nmt

TensorFlow Neural Machine Translation Tutorial
Apache License 2.0

Inference: how to generate translations? #323

Open ws309 opened 6 years ago

ws309 commented 6 years ago

I wrote my code following the tutorial and can train the model now, but when I try to generate translations the code never returns an answer; it just keeps running forever. Here is my code, which follows the tutorial:

SOS = 99
EOS = 16558
helper_infer = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding, tf.fill([batch_size], 99), 16558)
infer_decode = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper_infer, decoder_initial_state, output_layer=projectionlayer)
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(infer_decode, 60)
translations = outputs.sample_id
trans = sess.run(translations, feed_dict=test_fd)

where test_fd (time-major) is:

test_fd = {<tf.Tensor 'q_input:0' shape=(?, 30) dtype=int32>: array([[   99,    99,    99, ...,    99,    99,    99],
                                                                     [    0,    41, 13765, ...,  2046,   236,     0],
                                                                     [    2,     0,  8719, ...,     5,     5,     0],
                                                                     ...,
                                                                     [    0,     0,     0, ...,     0,     0,     0],
                                                                     [    0,     0,     0, ...,     0,     0,     0],
                                                                     [    0,     0,     0, ...,     0,     0,     0]]),
           <tf.Tensor 'q_seq_len:0' shape=(30,) dtype=int32>: array([ 6, 15,  9, 10, 14,  6,  8,  6,  8,  9,  7, 23,  8,  8,  7,
                                                                      8,  7,  9,  7,  7,  6, 10, 24, 13,  7,  8,  6,  9, 12,  7])}

At first I thought I should add SOS and EOS to test_fd, but it still doesn't produce any answers; it just never stops. There must be many tf.while_loop calls inside, but I don't know how to solve the problem. Can anybody help me? Thank you!
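A plain-Python sketch of the stopping rule may help explain why decoding can hang. It is a simplified model of greedy decoding, not TensorFlow's implementation: every sequence starts from SOS, each step feeds back the previous argmax token, and the loop ends only when every sequence has emitted EOS or a step cap is reached. Without a cap, a model that never predicts EOS loops forever. Worth checking in the code above: in TF 1.x the signature is `tf.contrib.seq2seq.dynamic_decode(decoder, output_time_major=False, impute_finished=False, maximum_iterations=None, ...)`, so a positional `60` binds to `output_time_major` and leaves `maximum_iterations` unset, whereas passing `maximum_iterations=60` by keyword caps the loop. `greedy_decode` and `step_fn` are hypothetical names; `step_fn` stands in for one decoder step plus argmax, and unlike `dynamic_decode` the sketch simply stops appending tokens to sequences that have already finished.

```python
from typing import Callable, List, Optional


def greedy_decode(
    step_fn: Callable[[List[int]], List[int]],
    batch_size: int,
    sos_id: int,
    eos_id: int,
    maximum_iterations: Optional[int] = None,
) -> List[List[int]]:
    """Simplified greedy decoding loop (hypothetical helper, not a TF API).

    step_fn is a hypothetical stand-in for one decoder step: given the
    previous token of each sequence, it returns the argmax token of each.
    """
    prev = [sos_id] * batch_size          # every sequence starts from SOS
    finished = [False] * batch_size
    outputs: List[List[int]] = [[] for _ in range(batch_size)]
    step = 0
    # The loop ends only when all sequences have emitted EOS, or when the
    # optional step cap is reached. With maximum_iterations=None and a model
    # that never predicts EOS, this never terminates.
    while not all(finished):
        if maximum_iterations is not None and step >= maximum_iterations:
            break
        tokens = step_fn(prev)
        for i, tok in enumerate(tokens):
            if not finished[i]:
                outputs[i].append(tok)    # EOS itself is kept in the output
                if tok == eos_id:
                    finished[i] = True
        prev = tokens                     # feed the argmax back in
        step += 1
    return outputs
```

Calling it with `maximum_iterations=None` and a `step_fn` that never returns `eos_id` would never return, which matches the behavior described in the question; with a cap of 60, each sequence is cut off after at most 60 tokens.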

luozhouyang commented 6 years ago

If the program keeps running, that means training has not finished. If you want to do inference, you can stop the training manually and then run the inference command:

python -m nmt.nmt \
    --inference_input_file=$INPUT_FILE \
    --inference_output_file=$OUTPUT_FILE \
    --ckpt=$CKPT_PATH \
    .. (other args)

Put your source sequences in $INPUT_FILE, and the translations will be written to $OUTPUT_FILE.
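A minimal Python sketch of the file round trip, assuming the nmt convention of one source sentence per line in the input file (tokenized the same way as the training data) and one translation per line, in the same order, in the output file. The flag names follow the repository README (`--out_dir`, `--inference_input_file`, `--inference_output_file`, `--ckpt`); the helper names `write_infer_input`, `build_infer_cmd`, and `read_translations` are hypothetical, and all paths are placeholders.

```python
from pathlib import Path
from typing import List


def write_infer_input(path: Path, sentences: List[str]) -> None:
    # Assumption: one source sentence per line, already tokenized
    # (space-separated) exactly like the training data.
    path.write_text("\n".join(sentences) + "\n", encoding="utf-8")


def build_infer_cmd(input_file: str, output_file: str, ckpt: str, out_dir: str) -> List[str]:
    # Hypothetical helper: argument list for `python -m nmt.nmt` in
    # inference mode; out_dir is the model directory used during training.
    return [
        "python", "-m", "nmt.nmt",
        f"--out_dir={out_dir}",
        f"--inference_input_file={input_file}",
        f"--inference_output_file={output_file}",
        f"--ckpt={ckpt}",
    ]


def read_translations(path: Path) -> List[str]:
    # Assumption: one translation per line, aligned with the input lines.
    return path.read_text(encoding="utf-8").splitlines()
```

Once the input file is written, the returned list can be passed to `subprocess.run(cmd, check=True)`, and `read_translations` then pairs each output line with the corresponding source line.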