Closed — digital10111 closed this issue 4 years ago
I also didn't quite understand: what do the values in `truth_probs` represent?
It won't fail. `one_hot_labels` is actually of shape `batch x input_max_len x (target_max_len-1) x vocab_size`.
`truth_probs` holds the log-probabilities of emitting the correct label at the next decoding time step.
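A minimal NumPy sketch of what `truth_probs` would contain under that description (all shapes, names, and the gather-based implementation here are my assumptions, not taken from the repo):

```python
import numpy as np

# Assumed shapes: log_probs is the log-softmax over the joint network output,
# (batch, input_max_len, target_max_len, vocab_size); labels is (batch, U)
# with U = target_max_len - 1.
batch, T, target_max_len, vocab = 2, 3, 5, 8
U = target_max_len - 1
rng = np.random.default_rng(0)
log_probs = rng.standard_normal((batch, T, target_max_len, vocab))
labels = rng.integers(0, vocab, size=(batch, U))

# truth_probs[b, t, u] = log-probability of emitting the correct next label
# labels[b, u] at grid point (t, u); indices broadcast over the time axis.
truth_probs = np.take_along_axis(
    log_probs[:, :, :-1, :],       # (batch, T, U, vocab)
    labels[:, None, :, None],      # (batch, 1, U, 1), broadcast over T
    axis=-1,
).squeeze(-1)                      # (batch, T, U)
print(truth_probs.shape)
```

So each entry is just the log-prob the model assigns to the ground-truth label at that `(time, label)` position of the lattice.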
https://github.com/mejanvijay/tensorflow_rnnt/blob/e18f10d82c8b0b815b80094dae5777aeae257e1b/rnnt_loss.py#L30
This multiply operator would fail when `input_max_len != (target_max_len-1)`. Basically, `labels` is `batch x (target_max_len-1)`. When converted to `one_hot_labels` it becomes `batch x (target_max_len-1) x (target_max_len-1) x vocab_size`. `logits` is `batch x input_max_len x target_max_len x vocab_size`. So when we do `tf.multiply(log_probs[:, :, :-1, :], one_hot_labels)`, it should fail if `input_max_len != (target_max_len-1)`. Our test cases are succeeding only because `input_max_len == (target_max_len-1)` in all of them, i.e. `input_max_len = 5` and `target_max_len = 6`.
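The shape question above can be checked with a small NumPy sketch (all shapes and names here are assumptions for illustration; `tf.multiply` follows the same broadcasting rules as NumPy's elementwise multiply):

```python
import numpy as np

# Deliberately pick input_max_len != target_max_len - 1.
batch, input_max_len, target_max_len, vocab = 2, 5, 9, 7
U = target_max_len - 1  # 8

labels = np.zeros((batch, U), dtype=np.int64)
one_hot = np.eye(vocab)[labels]  # (batch, U, vocab), like tf.one_hot(labels, vocab)
log_probs = np.zeros((batch, input_max_len, target_max_len, vocab))

# Multiplying directly fails: (batch, U, vocab) cannot broadcast
# against (batch, input_max_len, U, vocab) when input_max_len != batch dims align.
try:
    np.multiply(log_probs[:, :, :-1, :], one_hot)
except ValueError as e:
    print("direct multiply fails:", e)

# With an explicit length-1 time axis, broadcasting handles the mismatch,
# which is consistent with the reply that the op won't fail:
result = np.multiply(log_probs[:, :, :-1, :], one_hot[:, None, :, :])
print(result.shape)  # (2, 5, 8, 7)
```

In other words, whether the multiply fails depends entirely on whether `one_hot_labels` carries (or is expanded/tiled to) a time axis before the product, not on `input_max_len == target_max_len - 1` holding.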