iamjanvijay / rnnt

An implementation of RNN-Transducer loss in TF-2.0.
MIT License

What is the use of pred_grads? #5

Closed wyattsuen closed 3 years ago

wyattsuen commented 4 years ago

Hi, I have studied the source code, but the matrix operations are used in such an unusual way that they really confuse me. Maybe you could kindly point to good literature on the algorithm behind these diagonal matrices so we can understand it. Now, coming to my problem: in the example,

pred_loss, pred_grads = loss_grad_gradtape(logits, labels, label_lengths, logit_lengths)

is pred_loss meant to be used as the loss of a TensorFlow model? And what is the use of pred_grads?

Also, when I checked the source code, I found that the loss is computed as loss = -final_state_probs, with final_state_probs = beta[:, 0, 0].

That is, the loss comes only from backward_dp(), with no connection to forward_dp(). So I think pred_loss can't simply be used as a TensorFlow model loss. What is the correct way to train with TensorFlow? Is the following correct?

logits = some_deep_network(...)
pred_loss, pred_grads = loss_grad_gradtape(logits, labels, label_lengths, logit_lengths)
rnnt_model = tf.keras.Model(inputs=[logits, labels, label_lengths, logit_lengths], outputs=pred_loss)
rnnt_model.compile(optimizer='adam', loss=lambda y_true, y_pred: y_pred)
rnnt_model.fit(...)
iamjanvijay commented 3 years ago

Hi @wyattsuen!

Thanks for sharing your views. Sure, I'll put up a wiki for this repo to explain this diagonal-based algorithm.

Regarding the non-involvement of forward_dp() in the loss computation: the forward and backward losses are exactly the same if you assume there are no precision errors, so either of them can be used to compute the loss. At the same time, it is mandatory to compute both, because the alpha (forward-DP) and beta (backward-DP) matrices are both used in the gradient computation.
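To illustrate the equivalence, here is a minimal NumPy sketch (not the repo's code; the (T, U+1) lattice, the blank/emit log-probabilities, and the recursions follow Graves' RNN-T formulation, and the blank/emit arrays here are random stand-ins):

import numpy as np

# alpha[t, u]: log-prob of all partial alignments reaching node (t, u)
# from (0, 0); beta[t, u]: log-prob of all alignments from (t, u) to the
# terminal node. The total alignment log-probability can be read off
# either end: alpha[T-1, U] + blank[T-1, U] == beta[0, 0].

def log_add(a, b):
    m = np.maximum(a, b)
    return m + np.log(np.exp(a - m) + np.exp(b - m))

T, U = 4, 2  # time steps and label length (toy sizes)
rng = np.random.default_rng(0)
blank = np.log(rng.random((T, U + 1)))  # log P(blank | node (t, u))
emit = np.log(rng.random((T, U + 1)))   # log P(next label | node (t, u))

alpha = np.full((T, U + 1), -np.inf)
alpha[0, 0] = 0.0
for t in range(T):
    for u in range(U + 1):
        if t > 0:  # arrive by emitting blank at (t-1, u)
            alpha[t, u] = log_add(alpha[t, u], alpha[t - 1, u] + blank[t - 1, u])
        if u > 0:  # arrive by emitting label u at (t, u-1)
            alpha[t, u] = log_add(alpha[t, u], alpha[t, u - 1] + emit[t, u - 1])

beta = np.full((T, U + 1), -np.inf)
beta[T - 1, U] = blank[T - 1, U]  # termination: consume a final blank
for t in reversed(range(T)):
    for u in reversed(range(U + 1)):
        if t < T - 1:
            beta[t, u] = log_add(beta[t, u], beta[t + 1, u] + blank[t, u])
        if u < U:
            beta[t, u] = log_add(beta[t, u], beta[t, u + 1] + emit[t, u])

print(alpha[T - 1, U] + blank[T - 1, U], beta[0, 0])  # identical up to precision

Running it prints two numbers that agree up to floating-point error, which is exactly the precision caveat above; that the gradients need per-node terms from both alpha and beta is why both DPs are still computed.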

I'm not sure how to use it with a Keras .fit() call, but here's a sample way to use it: https://github.com/iamjanvijay/rnnt/blob/master/source/sample_train.py
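For readers who land here later: one way to wire pred_grads into a training step without .fit() is a custom loop (a minimal sketch, not sample_train.py verbatim; model, vocab_size, and the train_step arguments are hypothetical stand-ins, and loss_grad_gradtape is assumed to be in scope as in the example above):

import tensorflow as tf

vocab_size = 10  # hypothetical
model = tf.keras.Sequential([tf.keras.layers.Dense(vocab_size)])  # stand-in for the real network
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(features, labels, label_lengths, logit_lengths):
    with tf.GradientTape() as tape:
        logits = model(features, training=True)  # forward pass recorded on the tape
    # RNN-T loss and its gradient w.r.t. the logits, as returned by
    # loss_grad_gradtape (see the call in the question above).
    pred_loss, pred_grads = loss_grad_gradtape(logits, labels, label_lengths, logit_lengths)
    # Chain rule: d(loss)/d(weights) = pred_grads . d(logits)/d(weights),
    # computed as a vector-Jacobian product via output_gradients.
    weight_grads = tape.gradient(logits, model.trainable_variables, output_gradients=pred_grads)
    optimizer.apply_gradients(zip(weight_grads, model.trainable_variables))
    return pred_loss

This sidesteps Keras's loss plumbing entirely: pred_loss is only reported, while the weight gradients come from pred_grads.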