iamjanvijay / rnnt

An implementation of RNN-Transducer loss in TF-2.0.
MIT License

This multiply operator might fail. #1

Closed digital10111 closed 4 years ago

digital10111 commented 4 years ago

https://github.com/mejanvijay/tensorflow_rnnt/blob/e18f10d82c8b0b815b80094dae5777aeae257e1b/rnnt_loss.py#L30

This multiply operator would fail when input_max_len != (target_max_len-1).

Basically, labels is batch x (target_max_len-1). When converted to one_hot_labels, it becomes batch x (target_max_len-1) x (target_max_len-1) x vocab_size.

logits is batch x input_max_len x target_max_len x vocab_size.

So when we do tf.multiply(log_probs[:, :, :-1, :], one_hot_labels), it should fail if input_max_len != (target_max_len-1).

Our test cases succeed only because input_max_len == (target_max_len-1) in all of them, i.e. input_max_len = 5 and target_max_len = 6.

digital10111 commented 4 years ago

I also didn't quite understand what the values in truth_probs represent.

iamjanvijay commented 4 years ago

It won't fail. one_hot_labels actually has shape batch x input_max_len x (target_max_len-1) x vocab_size.
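For anyone checking the shapes, here is a minimal sketch of how one_hot_labels could end up as batch x input_max_len x (target_max_len-1) x vocab_size. It assumes the labels are repeated along a new input-time axis before one-hot encoding, which is consistent with the shape stated above but may not match rnnt_loss.py line for line. NumPy stands in for the TensorFlow ops, and the sizes and random values are made up, with input_max_len deliberately different from target_max_len-1.

```python
import numpy as np

# Hypothetical sizes, chosen so that input_max_len != target_max_len - 1.
batch, input_max_len, target_max_len, vocab_size = 2, 7, 4, 5

rng = np.random.default_rng(0)

# labels: batch x (target_max_len-1); index 0 is assumed to be the blank symbol.
labels = rng.integers(1, vocab_size, size=(batch, target_max_len - 1))

# Assumption: labels are repeated along a new input-time axis,
# giving batch x input_max_len x (target_max_len-1).
tiled_labels = np.tile(labels[:, None, :], (1, input_max_len, 1))

# One-hot encode: batch x input_max_len x (target_max_len-1) x vocab_size.
one_hot_labels = np.eye(vocab_size)[tiled_labels]

# Stand-in log-probabilities: batch x input_max_len x target_max_len x vocab_size.
probs = rng.dirichlet(np.ones(vocab_size), size=(batch, input_max_len, target_max_len))
log_probs = np.log(probs)

# Dropping the last target position makes both operands the same shape,
# so the elementwise multiply works even though 7 != 3.
product = log_probs[:, :, :-1, :] * one_hot_labels
print(one_hot_labels.shape, product.shape)
```

Under that assumption, the multiply only requires the two tensors to agree on all four axes, which they do for any input_max_len.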

iamjanvijay commented 4 years ago

truth_probs holds the log-probabilities of emitting the correct label at the next decoding time step.
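A small sketch of what that means in practice, assuming truth_probs is obtained by summing the masked log-probabilities over the vocabulary axis and that one_hot_labels is the labels repeated across input time (both assumptions, not confirmed against rnnt_loss.py). NumPy stands in for TensorFlow, and the short names T, U, V and the example labels are hypothetical.

```python
import numpy as np

# Hypothetical sizes: T = input_max_len, U = target_max_len, V = vocab_size.
batch, T, U, V = 1, 2, 3, 4

rng = np.random.default_rng(0)
logits = rng.normal(size=(batch, T, U, V))

# Log-softmax over the vocabulary axis.
log_probs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))

# labels: batch x (U-1); made-up target sequence [2, 1].
labels = np.array([[2, 1]])

# Assumed construction: repeat labels over input time, then one-hot encode.
one_hot_labels = np.eye(V)[np.tile(labels[:, None, :], (1, T, 1))]

# Summing over the vocabulary axis keeps only the entry for the correct label.
truth_probs = np.sum(log_probs[:, :, :-1, :] * one_hot_labels, axis=-1)

# truth_probs[b, t, u] equals log_probs[b, t, u, labels[b, u]].
print(truth_probs.shape)
print(np.isclose(truth_probs[0, 1, 0], log_probs[0, 1, 0, labels[0, 0]]))
```

So each entry truth_probs[b, t, u] would be the log-probability, at input step t with u labels already emitted, of emitting the next correct label labels[b, u].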