Closed · zh794390558 closed this issue 2 years ago
Why is the grad formula on CPU not equal to the GPU one? Why is `grad = math.exp(alphas[col] + betas[col] + logpk - logll[mb])`
not present in the CPU implementation?
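For reference, here is a minimal NumPy sketch of how that term is typically formed in a CUDA-style kernel (the names `alphas`, `betas`, `log_probs`, `logll` mirror the question and are illustrative, not the actual kernel variables):

```python
import numpy as np

def cuda_style_grad_term(alphas, betas, log_probs, logll):
    """Sketch of the term from the question, assuming:
    alphas, betas: (T, U) forward/backward log variables,
    log_probs:     (T, U, K) log-softmax outputs,
    logll:         scalar log-likelihood log P(y|x).
    Returns exp(alpha + beta + log p_k - logll) for every (t, u, k)."""
    # Broadcast (T, U, 1) + (T, U, 1) + (T, U, K) - scalar, then exponentiate.
    return np.exp(alphas[..., None] + betas[..., None] + log_probs - logll)
```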
LlForward is Eqn 10, not Eqn 9. So you can group the terms to become (1+lambda).
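As a sketch of the grouping being described (notation is an assumption loosely following the FastEmit paper, not a quote of its Eqn 9/10): the regularizer only touches the label (non-blank) term, and its gradient shares the same alpha/beta factor as the transducer gradient, so the two collapse into a single (1 + lambda) scale on the label gradient while the blank gradient is unchanged:

```latex
% FastEmit-regularized objective (sketch): \tilde{L} = L_{rnnt} + \lambda L_{fe}
% At lattice node (t, u), only the label (non-blank) gradient picks up the factor:
\[
\frac{\partial \tilde{\mathcal{L}}}{\partial \hat{y}(t,u)}
  = (1 + \lambda)\,\frac{\partial \mathcal{L}_{\text{rnnt}}}{\partial \hat{y}(t,u)},
\qquad
\frac{\partial \tilde{\mathcal{L}}}{\partial b(t,u)}
  = \frac{\partial \mathcal{L}_{\text{rnnt}}}{\partial b(t,u)} .
\]
```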
On CPU, we don't have the CUDA in-place, memory-efficient log-softmax operation, so its grad is different from the CUDA kernel, which takes the internal log-softmax calculation into account.
@titu1994 Thanks for your reply. I still have some trouble understanding the grad computation of the transducer; could you help a bit more?
> LlForward is Eqn 10, not Eqn 9. So you can group the terms to become (1+lambda).
I think LlForward includes the blank label prob, but FastEmit only tunes the predicted (non-blank) label prob, so we can't use (1+lambda) directly.
> On CPU, we don't have the CUDA in-place, memory-efficient log-softmax operation, so its grad is different from the CUDA kernel, which takes the internal log-softmax calculation into account.
I'm confused by this: won't `torch.nn.functional.log_softmax` be added to the backward graph anyway?
Can you explain what "takes into account the internal logsoftmax calculation" means, and how that is done?
> I think LlForward includes the blank label prob, but FastEmit only tunes the predicted (non-blank) label prob, so we can't use (1+lambda) directly.
If you look at the grad code for numpy with fastemit, the final (T, U) step is actually exactly the same with and without fastemit (alpha*beta has to be 1 at the (T, U) index). So it doesn't matter, and it can be calculated with alpha*beta. But we can compute alpha*beta without actually computing beta, by just doing the inference-step calculation of LlForward and then scaling it by lambda. I confirmed it with the authors a year ago.
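A sketch of the identity being relied on here, assuming the usual transducer recursions where `betas[T-1, U-1] = log b(T-1, U-1)` (only the final blank remains) and `logll = alphas[T-1, U-1] + betas[T-1, U-1]`:

```latex
% At the terminal lattice node (T-1, U-1) the normalized occupancy is exactly 1:
\[
\exp\!\bigl(\alpha_{T-1,U-1} + \beta_{T-1,U-1} - \log P(\mathbf{y}\mid\mathbf{x})\bigr)
  = \exp\!\bigl(\log P(\mathbf{y}\mid\mathbf{x}) - \log P(\mathbf{y}\mid\mathbf{x})\bigr)
  = 1 ,
\]
% so the gradient at that node is identical with or without the regularizer, and
% the alpha*beta product can be recovered from the forward (LlForward) pass
% alone, scaled by lambda, without materializing beta.
```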
For CPU, the code explicitly calls log_softmax in PyTorch, and backprop will take care of the gradient of log_softmax, so you only need to implement the RNNT loss grad. But that is inefficient because you need a lot of extra memory for it.
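Roughly, the CPU pattern looks like the sketch below (a minimal illustration, not the actual warprnnt_numba code; `RNNTLossCPU` and the placeholder gradients are made up for the example). The point is that `F.log_softmax` stays in the autograd graph, so `backward()` only has to return the gradient with respect to the log-probs:

```python
import torch
import torch.nn.functional as F

class RNNTLossCPU(torch.autograd.Function):
    """Sketch of the CPU pattern: the loss consumes log-probs produced by
    F.log_softmax inside the autograd graph, so backward() only returns
    dLoss/d(log-probs) and autograd chains the log-softmax Jacobian."""

    @staticmethod
    def forward(ctx, log_probs, grads):
        # `grads` stands in for dLoss/d(log-probs) computed from alphas/betas
        # (a placeholder here instead of the real forward-backward pass).
        ctx.save_for_backward(grads)
        return log_probs.new_zeros(())  # placeholder loss value

    @staticmethod
    def backward(ctx, grad_output):
        (grads,) = ctx.saved_tensors
        return grad_output * grads, None

# Usage: log_softmax is part of the graph, so its backward is automatic,
# at the cost of keeping the full (B, T, U, K) log-prob tensor in memory.
logits = torch.randn(2, 4, 3, 5, requires_grad=True)   # assumed (B, T, U, K)
log_probs = F.log_softmax(logits, dim=-1)
loss = RNNTLossCPU.apply(log_probs, torch.randn_like(log_probs))
loss.backward()
print(logits.grad.shape)
```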
On CUDA, you can efficiently compute log-softmax (much less memory), but now PyTorch autograd will no longer compute the backprop of log-softmax for you, so you need to do it manually.
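Concretely, "doing it manually" means applying the log-softmax Jacobian inside the kernel. A NumPy sketch of that chain rule (not the actual CUDA code):

```python
import numpy as np

def grad_through_log_softmax(logits, grad_wrt_log_probs):
    """Manual chain rule through y = log_softmax(x):
    dL/dx_k = dL/dy_k - softmax(x)_k * sum_j dL/dy_j  (over the vocab axis)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)          # stabilize
    softmax = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    total = grad_wrt_log_probs.sum(axis=-1, keepdims=True)
    return grad_wrt_log_probs - softmax * total
```

The `softmax * total` term is where the `exp(alphas[col] + betas[col] + logpk - logll[mb])` expression from the question comes from: at each lattice node the summed gradient with respect to the log-probs is `-exp(alpha + beta - logll)`, and multiplying by `softmax_k = exp(logpk)` gives exactly that term.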
@titu1994 Thank you very much for the detailed explanation.
https://github.com/titu1994/warprnnt_numba/blob/b1bc81e02dfb05143c3d55ac7b50c8131e85b194/warprnnt_numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py#L232