Closed phybrain closed 5 years ago
你指的是输入图片的长度还是 label的长度? 效果是针对中文吗? 中文6000类的话,效果和纯ctc比要弱一些,特别是在复杂场景下。
label长度,加上attention效果还要弱啊···············
label长度 在一个batch里是固定的 在不同的batch里 是变化的
"加上attention效果还要弱啊" 这只是我得出的结论 不一定对!!!
怎么改成不固定长度啊 ,求指教
@phybrain
感受一下 tf.sequence_mask 这个函数
假设batch_size = 3, 第一个样本的label长度为1, 第二个为3, 第三个为2, 最大label长度为5,
则:
tf.sequence_mask([1, 3, 2], 5) 生成如下的mask
[[True, False, False, False, False],
[True, True, True, False, False],
[True, True, False, False, False]]
attention_model.py的122-124行
mask = tf.cast(tf.sequence_mask(BATCH_SIZE * [train_length[0] - 1], train_length[0]), tf.float32)
att_loss = tf.contrib.seq2seq.sequence_loss(train_outputs[0].rnn_output, target_output,
weights=mask)
计算loss的时候,考虑mask即可
你指的是输入图片的长度还是 label的长度?