Proposes SeaRNN, a novel training algorithm for RNNs inspired by the "learning to search" (L2S) approach to structured prediction
Performance improves on OCR, spelling correction, and NMT tasks
Details
Motivation
The standard training algorithm for RNNs is MLE (Maximum Likelihood Estimation), where each conditional token probability is optimized via cross-entropy
weakness 1) 1/0 loss : it makes no distinction between candidates that are close to or far from the ground truth; only a single token is treated as the correct answer
weakness 2) exposure bias : training uses ground-truth tokens as the previous target inputs, whereas at test time the model's own previous outputs are fed back in, so the model may run into parts of the hypothesis space it has never visited during training
weakness 3) locality : the RNN is trained to predict the next token conditioned on all previous predictions, which is very local in nature. Since the MLE likelihood of a sequence is the product of the per-token probabilities, the loss is sequence-level in theory; but if, in practice, the model effectively conditions only on the last few predictions, it is local
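The MLE setup and its exposure-bias weakness can be sketched as teacher-forced cross-entropy. This is a minimal pure-Python sketch; the toy `step` function and 3-token probability table are illustrative assumptions, not the paper's model:

```python
import math

# Toy "RNN step": given the previous ground-truth token, return a
# probability distribution over a 3-token vocabulary. This is a
# stand-in for a real recurrent cell (assumption for illustration).
def step(prev_token, probs_table):
    return probs_table[prev_token]

def mle_loss(target, probs_table, bos=0):
    # Teacher forcing: condition on the ground-truth prefix, not on the
    # model's own predictions -> the source of exposure bias at test time.
    loss, prev = 0.0, bos
    for y in target:
        p = step(prev, probs_table)
        loss += -math.log(p[y])  # cross-entropy: only the gold token counts (the 1/0 view)
        prev = y                 # feed the ground-truth token, not the model's argmax
    return loss

probs_table = {0: [0.1, 0.7, 0.2], 1: [0.2, 0.2, 0.6], 2: [0.3, 0.3, 0.4]}
print(round(mle_loss([1, 2], probs_table), 4))
```

Note how the loop never consults the model's own predictions; at test time `prev` would instead be the decoded token, which is exactly the train/test mismatch described above.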
SeaRNN Algorithm
Roll-in : decode in a greedy manner up to length T
Roll-out : for each t in T, expand to all possible token choices and decode each continuation up to T
For each roll-out at step t, compute a cost-sensitive loss L(c) with cost function c
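The three steps above can be sketched with a toy deterministic policy. Everything here (the greedy `policy`, the Hamming-style cost) is an illustrative assumption standing in for the paper's learned model and task costs:

```python
# Minimal SeaRNN-style sketch: greedy roll-in, one roll-out per (step, action),
# then a per-step cost vector that feeds a cost-sensitive loss.
VOCAB = [0, 1, 2]

def policy(prefix):
    # Toy greedy policy: next token = (last token + 1) mod vocab size.
    return (prefix[-1] + 1) % len(VOCAB) if prefix else 0

def roll_in(T):
    # Decode greedily from scratch up to length T.
    seq = []
    for _ in range(T):
        seq.append(policy(seq))
    return seq

def roll_out(prefix, action, T):
    # Force `action` at the current step, then decode greedily to length T.
    seq = list(prefix) + [action]
    while len(seq) < T:
        seq.append(policy(seq))
    return seq

def hamming_cost(seq, target):
    # Cheap stand-in for a task cost (e.g. edit distance or 1 - BLEU).
    return sum(a != b for a, b in zip(seq, target))

def searnn_costs(target):
    T = len(target)
    base = roll_in(T)
    costs = []  # costs[t][a] = cost of choosing token a at step t
    for t in range(T):
        costs.append([hamming_cost(roll_out(base[:t], a, T), target)
                      for a in VOCAB])
    return costs

print(searnn_costs([0, 1, 2]))
```

Each row of the returned matrix compares every possible token choice at step t, so close and far candidates get different costs, unlike the 1/0 MLE loss.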
Questions
"Note that we do not need the test error to be differentiable, as our costs c_t(a) are fixed when we minimize our training loss." : how is it possible that the test error does not need to be differentiable?
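One way to read that sentence: the cost c_t(a) for each candidate is computed once with any metric, even a non-differentiable one, and then treated as a fixed constant, so gradients only need to flow through the model's probabilities. A hedged sketch (the expected-cost form of the loss and the toy logits are assumptions for illustration, not the paper's exact loss):

```python
import math

def softmax(scores):
    exps = [math.exp(s) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

def cost_sensitive_loss(scores, costs):
    # `costs` may come from a non-differentiable metric (edit distance,
    # 1 - BLEU, ...); by the time we minimize, they are plain floats,
    # i.e. constants w.r.t. the model parameters. Only the softmax over
    # the model's scores needs to be differentiable.
    probs = softmax(scores)
    return sum(c * p for c, p in zip(costs, probs))  # expected cost

scores = [2.0, 0.5, -1.0]   # toy model logits for 3 candidate tokens
costs = [0.0, 3.0, 3.0]     # fixed numbers from a non-differentiable metric
print(round(cost_sensitive_loss(scores, costs), 4))
```

Minimizing this pushes probability mass toward low-cost candidates without ever differentiating the metric that produced the costs.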
Experiment
Neural Machine Translation
IWSLT'14 En-De, word-level
better results with the conv and dropout configurations
Personal Thoughts
need to re-read to fully understand
need to read L2S
it's interesting to see a different training approach from MLE and RL
does it really solve the exposure bias problem? Would more data resolve it? Would it outperform MLE training on large-scale NMT datasets?
Link : https://arxiv.org/pdf/1706.04499.pdf
Authors : Leblond et al. 2018