ssampang / im2latex

Tensorflow port of im2markup project

Multi layer decoder issue #6

Open longzaitianguo opened 6 years ago

longzaitianguo commented 6 years ago

@ssampang Hi, embedding_attention_decoder currently uses a single-layer LSTM decoder. To improve performance, I want to use a multi-layer LSTM decoder. However, when I pass a multi-layer LSTM cell to embedding_attention_decoder, I get the following error:

Traceback (most recent call last):
  File "im2latex.py", line 295, in <module>
    debug_output, (output, state) = build_model(inp, batch_size, num_rows, num_columns, num_words)
  File "im2latex.py", line 264, in build_model
    feed_previous=True)
  File "/home/guolong/math_ocr/RFR-solution/3_im2latex/decoder.py", line 177, in embedding_attention_decoder
    loop_function=loop_function)
  File "/home/guolong/math_ocr/RFR-solution/3_im2latex/decoder.py", line 138, in attention_decoder
    outputs = tf.while_loop(cond, body, loop_vars, shape_invariants)
  File "/home/guolong/venv-latex/local/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2636, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/home/guolong/venv-latex/local/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2469, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home/guolong/venv-latex/local/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2419, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home/guolong/math_ocr/RFR-solution/3_im2latex/decoder.py", line 104, in body
    cell_output, state = cell(x, state)
  File "/home/guolong/venv-latex/local/lib/python2.7/site-packages/tensorflow/python/ops/rnn_cell.py", line 815, in __call__
    cur_inp, new_state = cell(cur_inp, cur_state)
  File "/home/guolong/venv-latex/local/lib/python2.7/site-packages/tensorflow/python/ops/rnn_cell.py", line 308, in __call__
    c, h = state
  File "/home/guolong/venv-latex/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 510, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

Here is the code I wrote for the multi-layer LSTM:

dec_lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dec_lstm_dim, state_is_tuple=True)
number_of_layers = 2
if number_of_layers > 1:
    dec_multi_lstm_cell = tf.nn.rnn_cell.MultiRNNCell([dec_lstm_cell] * number_of_layers, state_is_tuple=True)
else:
    dec_multi_lstm_cell = dec_lstm_cell

decoder_output = decoder.embedding_attention_decoder(dec_init_state, \
                                                     tf.reshape(encoder_output, \
                                                                [batch_size, -1, \
                                                                 2 * enc_lstm_dim]), \
                                                     #dec_lstm_cell, \
                                                     dec_multi_lstm_cell, \
                                                     vocab_size, \
                                                     dec_seq_len, \
                                                     batch_size, \
                                                     embedding_size, \
                                                     feed_previous=True)

Looking forward to your reply.
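
One possible reading of the traceback above, offered as an assumption rather than a confirmed diagnosis: a multi-layer cell with state_is_tuple=True hands state[i] to layer i, so it expects one (c, h) pair per layer. If dec_init_state is still a single (c, h) pair built for the one-layer decoder (not shown in this issue), then state[0] is the bare c tensor, and the `c, h = state` line inside the first layer fails. The sketch below reproduces that mechanism in plain Python. FakeTensor, lstm_step, multi_layer_step, make_layer_state and try_step are hypothetical stand-ins written for illustration; they are not TensorFlow classes or functions.

```python
import collections

# Stand-in for illustration only; this is NOT the TensorFlow class.
LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))


class FakeTensor(object):
    """Mimics the one tensor behaviour that matters here: it cannot be unpacked."""

    def __init__(self, name):
        self.name = name

    def __iter__(self):
        raise TypeError("'Tensor' object is not iterable.")


def lstm_step(inp, state):
    # Same unpacking as the `c, h = state` line at the bottom of the traceback.
    c, h = state
    return h, LSTMStateTuple(c, h)


def multi_layer_step(num_layers, inp, state):
    # Assumed behaviour of a tuple-state multi-layer cell: layer i gets state[i].
    new_states = []
    for i in range(num_layers):
        inp, new_state = lstm_step(inp, state[i])
        new_states.append(new_state)
    return inp, tuple(new_states)


def make_layer_state(layer):
    return LSTMStateTuple(FakeTensor("c%d" % layer), FakeTensor("h%d" % layer))


def try_step(num_layers, state):
    """Return None if the step succeeds, otherwise the error message."""
    try:
        multi_layer_step(num_layers, FakeTensor("x"), state)
        return None
    except TypeError as err:
        return str(err)


# A single (c, h) pair, as a one-layer decoder would use.
single_layer_state = make_layer_state(0)
# One (c, h) pair per layer, as a two-layer cell would need.
two_layer_state = tuple(make_layer_state(i) for i in range(2))

error_with_single_state = try_step(2, single_layer_state)
error_with_nested_state = try_step(2, two_layer_state)
print(error_with_single_state)
print(error_with_nested_state)
```

If this assumption holds for decoder.py, the initial state passed to embedding_attention_decoder, and any state shape invariants given to tf.while_loop, would need to carry one (c, h) pair per layer instead of a single pair.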