ray075hl / attention-ocr-toy-example

Apache License 2.0

Is attention sensitive to sequence length? #5

Closed fendaq closed 5 years ago

fendaq commented 6 years ago

Hi, I tested the code you provided. Accuracy is quite high on 3-digit images, but it drops noticeably on 4- and 5-digit ones. Is there a good way to fix this?

ray075hl commented 6 years ago

@fendaq I ran into the same problem when doing Chinese text recognition, and I am still looking for the root cause. A couple of observations from my experiments so far:

1. The initial state passed to the attention decoder should be the state produced by the bidirectional RNN encoder, not an all-zero vector; the latter tends to make the very first decoding step wrong.
2. Attention seems sensitive to noise (blur): the alignment often drifts and characters get skipped, while a CTC branch trained jointly does not have this problem.

As for the length you mention: I found that a clear image with a dozen or so characters in one line can still be recognized perfectly, so I believe the real issue is noise sensitivity rather than length. Just my personal take, for reference only.

fendaq commented 6 years ago

> 1. The initial state passed to the attention decoder should be the state produced by the bidirectional RNN encoder, not an all-zero vector; the latter tends to make the very first decoding step wrong.

    enc_outputs, encoder_state = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell,
                                                                 cell_bw=cell,
                                                                 inputs=cnn_out,
                                                                 dtype=tf.float32)
    ...

    initial_state = attn_cell.zero_state(BATCH_SIZE, tf.float32).clone(cell_state=encoder_state)

    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=attn_cell, helper=helper,
        initial_state=initial_state,
        output_layer=output_layer)

Is this the right way to change it? On my side it throws an error and won't run...

    Traceback (most recent call last):
      File "/data/attention-ocr-toy-example/attention_model.py", line 195, in <module>
        main()
      File "/data/attention-ocr-toy-example/attention_model.py", line 190, in main
        loss, train_one_step, train_decode_result, pred_decode_result = build_compute_graph()
      File "/data/attention-ocr-toy-example/attention_model.py", line 115, in build_compute_graph
        train_outputs = decode(train_helper, train_output_embed, enc_state, 'decode')
      File "/data/attention-ocr-toy-example/attention_model.py", line 99, in decode
        impute_finished=True, maximum_iterations=MAXIMUMDECODE_ITERATIONS)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 286, in dynamic_decode
        swap_memory=swap_memory)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2816, in while_loop
        result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2640, in BuildLoop
        pred, body, original_loop_vars, loop_vars, shape_invariants)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2590, in _BuildLoop
        body_result = body(*packed_vars_for_body)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 234, in body
        decoder_finished) = decoder.step(time, inputs, state)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 138, in step
        cell_outputs, cell_state = self._cell(inputs, state)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 183, in __call__
        return super(RNNCell, self).__call__(inputs, state)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 575, in __call__
        outputs = self.call(inputs, *args, **kwargs)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py", line 1295, in call
        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 183, in __call__
        return super(RNNCell, self).__call__(inputs, state)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 575, in __call__
        outputs = self.call(inputs, *args, **kwargs)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 320, in call
        kernel_initializer=self._kernel_initializer)
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1154, in __init__
        shapes = [a.get_shape() for a in args]
      File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1154, in <listcomp>
        shapes = [a.get_shape() for a in args]
    AttributeError: 'tuple' object has no attribute 'get_shape'

ray075hl commented 6 years ago

@fendaq Use `clone(cell_state=encoder_state[0])`:

    enc_outputs, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell,
                                                             cell_bw=cell,
                                                             inputs=cnn_out,
                                                             dtype=tf.float32)

    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=attn_cell, helper=helper,
        initial_state=attn_cell.zero_state(dtype=tf.float32,
                                           batch_size=batch_size).clone(cell_state=enc_state[0]),
        output_layer=output_layer)

ray075hl commented 6 years ago

Also, at training time use

    train_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(output_embed, att_train_length,
                                                                       embeddings, sample_rate)

See issue #3 for details.
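For context, what `ScheduledEmbeddingTrainingHelper` adds over the plain `TrainingHelper` is scheduled sampling: at each decoder step during training, the model's own sampled prediction is fed back with some probability instead of always feeding the ground-truth token. A plain-Python sketch of that feeding rule (the function name and arguments are illustrative, not the TensorFlow API):

```python
import random

def next_decoder_input(ground_truth_token, predicted_token,
                       sampling_probability, rng=random):
    """Scheduled-sampling feeding rule (illustrative sketch).

    With probability `sampling_probability`, feed the model's own previous
    prediction back in; otherwise teacher-force the ground-truth token.
    """
    if rng.random() < sampling_probability:
        return predicted_token
    return ground_truth_token

# sampling_probability = 0.0 reduces to pure teacher forcing (TrainingHelper);
# raising it exposes the decoder to its own mistakes during training, which
# narrows the train/inference mismatch.
```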

fendaq commented 6 years ago

nmt does not use ScheduledEmbeddingTrainingHelper, yet it still predicts correctly, so I think ScheduledEmbeddingTrainingHelper is not strictly necessary.