HawkAaron / warp-transducer

A fast parallel implementation of RNN Transducer.
Apache License 2.0

CPU gradient definitions #69

Open xkr1 opened 4 years ago

xkr1 commented 4 years ago

Hi,

Firstly, thanks so much for making this code available. It's a really useful resource for learning more about RNN-T.

I have two questions, if I may:

i.) In `cpu_rnnt.h`, the gradients are defined as:

```cpp
ProbT g = alphas[idx(t, u)] + betas[idx(t+1, u)];
grad[idx(t, u, blank_)] = -std::exp(log_probs[idx(t, u) * 2] + g - loglike);
```

and

```cpp
ProbT g = alphas[idx(t, u)] + betas[idx(t, u+1)];
grad[idx(t, u, labels[u])] = -std::exp(log_probs[idx(t, u) * 2 + 1] + g - loglike);
```

Following rnnt_notes.pdf, as far as I understand, the formula used here is equation 9. From it, I can account for the following part of the code above:

```cpp
ProbT g = alphas[idx(t, u)] + betas[idx(t+1, u)];
grad[idx(t, u, blank_)] = -std::exp(g - loglike);
```

but I'm not sure where the `log_probs[]` term comes in.
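To see both expressions working end to end, here is a minimal pure-Python sketch of the forward-backward pass plus the two gradient lines, checked against central finite differences on a tiny random lattice. It is an illustration under my own assumptions, not the library's code: `lp[t][u]` holds `[log blank prob, log prob of the next target label]` at node (t, u), loosely mirroring the `log_probs[idx(t, u) * 2]` / `log_probs[idx(t, u) * 2 + 1]` pairs above; every log probability is treated as an independent input (no softmax); the final blank at (T-1, U-1), where `betas[idx(t+1, u)]` would fall off the lattice, is handled explicitly; and `forward_backward`, `grads_wrt_log_probs` and `loss` are hypothetical helper names.

```python
import math
import random

def lse(xs):
    # Numerically stable log(sum(exp(x) for x in xs)).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def forward_backward(lp, T, U):
    # lp[t][u] = [log p(blank | t, u), log p(next label | t, u)]  (assumed layout)
    # The label entry at u == U - 1 is never used: no labels are left to emit.
    alphas = [[0.0] * U for _ in range(T)]
    for t in range(T):
        for u in range(U):
            if t == 0 and u == 0:
                continue  # alpha(0, 0) = log 1
            cands = []
            if t > 0:
                cands.append(alphas[t - 1][u] + lp[t - 1][u][0])  # blank from (t-1, u)
            if u > 0:
                cands.append(alphas[t][u - 1] + lp[t][u - 1][1])  # label from (t, u-1)
            alphas[t][u] = lse(cands)
    betas = [[0.0] * U for _ in range(T)]
    for t in reversed(range(T)):
        for u in reversed(range(U)):
            if t == T - 1 and u == U - 1:
                betas[t][u] = lp[t][u][0]  # the final blank terminates every path
                continue
            cands = []
            if t < T - 1:
                cands.append(lp[t][u][0] + betas[t + 1][u])  # emit blank, go to (t+1, u)
            if u < U - 1:
                cands.append(lp[t][u][1] + betas[t][u + 1])  # emit label, go to (t, u+1)
            betas[t][u] = lse(cands)
    return alphas, betas

def grads_wrt_log_probs(lp, T, U):
    alphas, betas = forward_backward(lp, T, U)
    loglike = betas[0][0]  # ln P(y|x)
    grad = [[[0.0, 0.0] for _ in range(U)] for _ in range(T)]
    for t in range(T):
        for u in range(U):
            if t < T - 1:
                g = alphas[t][u] + betas[t + 1][u]
                grad[t][u][0] = -math.exp(lp[t][u][0] + g - loglike)
            elif u == U - 1:
                # Final blank at (T-1, U-1): there is no beta(T, U-1), so g is alpha alone.
                grad[t][u][0] = -math.exp(lp[t][u][0] + alphas[t][u] - loglike)
            if u < U - 1:
                g = alphas[t][u] + betas[t][u + 1]
                grad[t][u][1] = -math.exp(lp[t][u][1] + g - loglike)
    return loglike, grad

def loss(lp, T, U):
    return -forward_backward(lp, T, U)[1][0][0]  # L = -ln P(y|x) = -beta(0, 0)

random.seed(0)
T, U = 4, 3  # 4 frames, 2 target labels (U = number of labels + 1)
lp = [[[math.log(random.uniform(0.1, 0.9)) for _ in range(2)] for _ in range(U)]
      for _ in range(T)]

loglike, grad = grads_wrt_log_probs(lp, T, U)
alphas, _ = forward_backward(lp, T, U)
loglike_from_alphas = alphas[T - 1][U - 1] + lp[T - 1][U - 1][0]

# Central finite differences of L with respect to every log-probability entry.
eps = 1e-6
max_err = 0.0
for t in range(T):
    for u in range(U):
        for k in range(2):
            lp[t][u][k] += eps
            up = loss(lp, T, U)
            lp[t][u][k] -= 2 * eps
            down = loss(lp, T, U)
            lp[t][u][k] += eps  # restore the original value
            max_err = max(max_err, abs((up - down) / (2 * eps) - grad[t][u][k]))

print(loglike, loglike_from_alphas, max_err)
```

Because the check perturbs log probabilities rather than probabilities, the matching gradient carries the `lp` term inside the exponent; the shorter form `-exp(g - loglike)` would not match these finite differences.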

ii.) Following on from my first question, it's probably obvious that I'm not an expert in calculus. Could you point me to any resources, or give any hints, so that I can learn for myself how to derive the gradient of the loss from eq. 5, 7 and 8 in rnnt_loss.pdf? Even the smallest hint would be greatly appreciated!

Thanks again!

HawkAaron commented 4 years ago

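Well, the gradients here are with respect to the log probability:

$$\frac{\partial L}{\partial \log p} = \frac{\partial L}{\partial p}\,\frac{\partial \exp(\log p)}{\partial \log p} = \frac{\partial L}{\partial p}\,\exp(\log p) = p\,\frac{\partial L}{\partial p}.$$

That extra factor of $p$ is where `log_probs` comes in: in the log domain, multiplying by $p$ means adding the corresponding `log_probs` entry inside the `std::exp`.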
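A quick numerical check of this identity, and of how it links the two snippets in the question. The one structural assumption is that $P(y|x)$ is affine in any single arc probability $p$, since a path through the lattice uses a given arc at most once; so it can be written as $a \cdot p + b$, with $a = \alpha(t,u)\,\beta(t+1,u)$ for a blank arc and $b$ the mass of all paths avoiding that arc. The constants are arbitrary and `loss_of_p` / `loss_of_logp` are hypothetical names.

```python
import math

# Hypothetical split of P(y|x) around one blank arc at (t, u) with probability p:
#   P = a * p + b, where a = alpha(t, u) * beta(t+1, u) and b is the mass of all
#   paths that avoid this arc. The numbers below are arbitrary.
a, b = 0.7, 0.2
p = 0.35

def loss_of_p(q):
    return -math.log(a * q + b)  # L = -ln P

def loss_of_logp(lq):
    return loss_of_p(math.exp(lq))  # the same loss, parameterised by log p

eps = 1e-6
dL_dp = (loss_of_p(p + eps) - loss_of_p(p - eps)) / (2 * eps)
dL_dlogp = (loss_of_logp(math.log(p) + eps) - loss_of_logp(math.log(p) - eps)) / (2 * eps)

# Closed forms: dL/dp     = -a / P      (log domain: -exp(g - loglike))
#               dL/dlog p = -a * p / P  (log domain: -exp(log p + g - loglike))
P = a * p + b
g, loglike = math.log(a), math.log(P)  # g plays the role of alphas[...] + betas[...]
print(dL_dp, -math.exp(g - loglike))
print(dL_dlogp, p * dL_dp, -math.exp(math.log(p) + g - loglike))
```

So `-exp(g - loglike)` is the gradient with respect to $p$ itself, and adding $\log p$ in the exponent turns it into the gradient with respect to $\log p$.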

xkr1 commented 4 years ago

Thank you! That makes a lot more sense to me now :-)

Referring now to the `g - loglike` component (in the log domain): I understand how `g` is derived, but so far it's been a mystery to me why the `- loglike` is there.

Would I be right in saying that this is because the loss $L$ is defined in the log domain, i.e. $L = -\ln P(y|x)$, so that applying the chain rule to the derivative of the log divides everything by $P(y|x)$, which in the log domain amounts to subtracting `loglike` $= \ln P(y|x)$?
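Writing the blank case out in full supports that reading. This is a sketch in my own notation, not taken from the notes: $p$ is the blank probability at lattice node $(t,u)$, $R$ is the probability mass of all paths that do not use that arc, and $L = -\ln P(y|x)$ as above.

```math
\begin{aligned}
P(y|x) &= \alpha(t,u)\,p\,\beta(t+1,u) + R, \\
\frac{\partial L}{\partial p} &= -\frac{1}{P(y|x)}\,\frac{\partial P(y|x)}{\partial p} = -\frac{\alpha(t,u)\,\beta(t+1,u)}{P(y|x)}, \\
\frac{\partial L}{\partial \ln p} &= p\,\frac{\partial L}{\partial p} = -\exp\bigl(\ln p + \underbrace{\ln\alpha(t,u) + \ln\beta(t+1,u)}_{g} - \underbrace{\ln P(y|x)}_{\text{loglike}}\bigr).
\end{aligned}
```

The label case is identical with $p$ replaced by the label probability and $\beta(t+1,u)$ by $\beta(t,u+1)$.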

q121q commented 4 years ago

How is this different from CTC loss?