k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Penalizing and measuring symbol delay #955

Open danpovey opened 2 years ago

danpovey commented 2 years ago

I'd like to have a way to penalize symbol delay in RNN-T computations, and to measure it during training. (This is relevant for systems that use time masking, to avoid the time-masking encouraging the network to delay symbols).

My proposal is this: we introduce a user-specifiable constant alpha >= 0 that says how much we'll penalize symbol delay. Suppose we have sequences with lengths S_n, T_n (n is the batch index), i.e. there are S_n symbols in sequence n and T_n frames; px[n,s,t] will contain the log-probability of emitting a non-blank symbol starting from position s,t of sequence n.

To penalize symbol delay, we'll do: px[n,s,t] += offset_n - alpha t. [We could equivalently do this on the py's.] The "- alpha t" term will encourage emitting symbols earlier (because we reduce the log-probs of later time-steps). offset_n is a constant per sequence that is not going to affect where symbols are emitted but which keeps the loss function at about the old value; we will set offset_n = alpha (T_n-1) / 2, i.e. alpha times the average of [0,1,..T_n-1], which are all the possible positions a symbol could be emitted. We'll have to do this in at least 2 different places: once for the "simple" loss and once for the "pruned" loss.
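A minimal sketch of that update, in plain Python rather than PyTorch so it runs without dependencies. The function names, the nested-list layout `px[n][s][t]`, and the choice to leave padded frames (t >= T_n) untouched are assumptions for illustration, not the k2 API.

```python
def delay_penalty(T_n, alpha):
    """Additive term offset_n - alpha * t for t in 0..T_n-1 (hypothetical helper)."""
    offset_n = alpha * (T_n - 1) / 2.0  # alpha times the mean of [0, 1, ..., T_n-1]
    return [offset_n - alpha * t for t in range(T_n)]


def apply_delay_penalty(px, lengths, alpha):
    """Return a copy of px with px[n][s][t] += offset_n - alpha * t.

    px:      nested list indexed [n][s][t] of non-blank log-probs (assumed layout)
    lengths: list of T_n, the number of frames of each sequence
    Frames with t >= T_n are treated as padding and left unchanged (assumption).
    """
    out = []
    for px_n, T_n in zip(px, lengths):
        pen = delay_penalty(T_n, alpha)
        out.append([
            [(lp + pen[t]) if t < T_n else lp for t, lp in enumerate(row)]
            for row in px_n
        ])
    return out


if __name__ == "__main__":
    px = [[[0.0] * 5, [0.0] * 5]]  # one sequence, S_n = 2 symbols, T_n = 5 frames
    print(apply_delay_penalty(px, [5], alpha=0.5)[0][0])
```

Because offset_n is alpha times the mean frame index, the added terms sum to zero over the T_n frames, which is why the loss stays at about its old value.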

To measure symbol delay (for diagnostic purposes), I propose to measure it only from the "simple" model, because we compute those grads anyway during the forward pass. I would like to define "normalized symbol delay" as the average (over symbols) of: (position symbol is emitted, in [0,1,..T_n-1]) - (T_n-1)/2, i.e. it's positive if symbols tend to be emitted later than the middle of the audio, on average; negative if before. This should be computable from the sum or average of px_grad * arange(T+1), since px_grad[n,s,t] is the probability of the symbol at position s being emitted on frame t. We can average and report this "sym_delay" with our other stats, if alpha is specified.
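A rough sketch of the diagnostic, again in plain Python with an assumed nested-list layout `px_grad[n][s][t]`. It takes the midpoint (T_n-1)/2 as the reference, which is what "the middle of the audio" refers to; the function name is hypothetical.

```python
def normalized_symbol_delay(px_grad, lengths):
    """Average over symbols of (expected emission frame) - (T_n - 1) / 2.

    px_grad[n][s][t] is assumed to be the occupation probability that symbol s
    of sequence n is emitted on frame t, so each row sums to (about) 1 over t.
    """
    total = 0.0
    count = 0
    for grad_n, T_n in zip(px_grad, lengths):
        midpoint = (T_n - 1) / 2.0
        for row in grad_n:
            # Same quantity as sum(px_grad * arange(...)) over the time axis.
            expected_frame = sum(p * t for t, p in enumerate(row))
            total += expected_frame - midpoint
            count += 1
    return total / count if count else 0.0


if __name__ == "__main__":
    # One sequence with T_n = 5: one symbol on the last frame, one in the middle.
    px_grad = [[[0, 0, 0, 0, 1], [0, 0, 1, 0, 0]]]
    print(normalized_symbol_delay(px_grad, [5]))
```

In this toy input the per-symbol delays are 4 - 2 = 2 and 2 - 2 = 0, so the reported value is 1.0, meaning symbols are emitted one frame later than the midpoint on average.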

This "sym_delay" does not directly reflect the real delay because it could be that our corpus has more trailing than preceding silence; but it's valid for comparisons, and we can always compute this delay from a reference system that was trained with alpha=0 and no time masking, which should reflect the actual time symbols were spoken (since there is a time-symmetry which should prevent symbols from being emitted, on average, earlier or later than the true time).

pkufool commented 2 years ago

I will try this idea in this pull request (https://github.com/k2-fsa/icefall/pull/242).

desh2608 commented 2 years ago

BTW FastEmit is a commonly-used regularization technique to reduce emission latency. It simply boosts the log-prob on non-blank tokens at each time frame.

danpovey commented 2 years ago

I suspect that what I described is equivalent to whatever it was they implemented. But that paper is very unclear, I can't make sense of it.

I had another look at the FastEmit paper. Towards the end of the paper they describe scaling up the grads for the non-blank labels by (1+lambda), leaving blank grads unscaled; it seems that this is the only change needed in practice. I looked at the implementation in the warp-rnnt code and, there, this is not done inside the recursion; it is done after the recursion. It seems to me that scaling the grads this way (outside the recursion) would not affect the alignments at all. If you think about it at the whole-path level, every path has the same number of blanks and the same number of non-blanks, so "earlier" paths will be affected by higher non-blank probs in the same way as "later" paths; it would just be similar to decoding with a blank penalty.
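The whole-path argument can be checked on a toy lattice. The sketch below is only an illustration of that argument, not the FastEmit or warp-rnnt code: it enumerates every alignment of S symbols over T frames, scores them with arbitrary made-up values, and compares path posteriors after (a) adding the same constant to every non-blank score and (b) subtracting a frame-dependent penalty alpha * t. All names and numbers are assumptions.

```python
import itertools
import math


def path_score(frames, px, py, T):
    """Score of one alignment; frames[s] is the frame on which symbol s is emitted."""
    score = sum(px[s][t] for s, t in enumerate(frames))
    for t in range(T):
        s = sum(1 for f in frames if f <= t)  # symbols emitted before leaving frame t
        score += py[s][t]  # the blank that advances from frame t
    return score


def posteriors(px, py, S, T):
    """Posterior probability of every monotonic alignment (brute force)."""
    paths = list(itertools.combinations_with_replacement(range(T), S))
    scores = [path_score(p, px, py, T) for p in paths]
    m = max(scores)
    log_z = m + math.log(sum(math.exp(x - m) for x in scores))
    return {p: math.exp(x - log_z) for p, x in zip(paths, scores)}


def mean_frame(post):
    """Expected emission frame, averaged over symbols."""
    return sum(p * sum(f) / len(f) for f, p in post.items())


S, T = 2, 4
# Arbitrary unnormalized scores, chosen only to make the lattice non-trivial.
px = [[math.sin(1.0 + 3 * s + t) for t in range(T)] for s in range(S)]
py = [[math.cos(2.0 + 5 * s + t) for t in range(T)] for s in range(S + 1)]

base = posteriors(px, py, S, T)
# (a) Uniform boost on every non-blank: each path gains exactly S * 0.7.
boosted = posteriors([[v + 0.7 for v in row] for row in px], py, S, T)
# (b) Frame-dependent penalty: later emissions lose more.
penalized = posteriors(
    [[v - 0.5 * t for t, v in enumerate(row)] for row in px], py, S, T
)

if __name__ == "__main__":
    print("max posterior change, uniform boost:",
          max(abs(base[k] - boosted[k]) for k in base))
    print("mean frame, base vs. penalized:",
          mean_frame(base), mean_frame(penalized))
```

The uniform boost leaves every path posterior unchanged up to rounding, while the "- alpha t" term moves the expected emission frame earlier. The same reasoning explains why a per-sequence constant like offset_n cannot change where symbols are emitted.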

I had a look at the papers citing FastEmit and, while others at Google seem to be using it successfully, the Facebook guys https://arxiv.org/pdf/2104.02207.pdf didn't find that it improved endpoint lag. I suspect that they have not described accurately what they were doing.

desh2608 commented 2 years ago

I was looking at the FastEmit implementation in NeMo's transducer loss (which is written in Numba), and it seems they do the regularization "inside" the recursion here.

danpovey commented 2 years ago

I don't think that's inside the recursion. The compute_alphas() and compute_betas() have loops; compute_grads() has no loop except over the vocabulary size.