csukuangfj / transducer-loss-benchmarking

Is the output of utils.generate_data() logprobs or probs? #15

Closed by season95 1 year ago

season95 commented 1 year ago

In benchmark_k2_pruned.py, line 112, k2.rnnt_loss_smoothed() takes decoder_out and encoder_out as input, both of which are generated by utils.generate_data(). However, I am confused about what decoder_out and encoder_out represent. On one hand, utils.generate_data() generates encoder_out with torch.rand(), so its elements all lie in [0, 1) and look like 'un-normalized probs'. On the other hand, according to the description of k2.rnnt_loss_smoothed(), the input 'encoder_out' is supposed to be 'un-normalized logprobs'.

So, are decoder_out and encoder_out 'probs' or 'logprobs'?
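To make the question concrete, here is a minimal sketch in plain Python (no PyTorch needed) that checks whether a row of values in [0, 1) could be a normalized probability or log-probability distribution. The `row` variable and both helper functions are hypothetical stand-ins for illustration; they are not part of utils.generate_data() or k2.

```python
import math
import random

random.seed(0)

# Hypothetical stand-in for one row of encoder_out: values drawn uniformly
# from [0, 1), mimicking what torch.rand() would produce.
vocab_size = 5
row = [random.random() for _ in range(vocab_size)]


def is_prob_distribution(values, tol=1e-6):
    """True if values are non-negative and sum to 1 (normalized probs)."""
    return all(v >= 0 for v in values) and abs(sum(values) - 1.0) < tol


def is_logprob_distribution(values, tol=1e-6):
    """True if exp(values) sums to 1 (normalized log-probs)."""
    return abs(sum(math.exp(v) for v in values) - 1.0) < tol


print("normalized probs?    ", is_prob_distribution(row))
print("normalized log-probs?", is_logprob_distribution(row))
```

Both checks fail for this row, so the values are un-normalized under either reading. Note that normalized log-probs can never be positive, while every value here is non-negative.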

csukuangfj commented 1 year ago

Neither.

They are the output of nn.Linear.

season95 commented 1 year ago

> Neither.
>
> They are the output of nn.Linear.

Thanks for replying! So that is to say, decoder_out and encoder_out are just logits, and we treat them as logprobs in rnnt_loss_smoothed()?

csukuangfj commented 1 year ago

Are logits also known as unnormalized probs?

season95 commented 1 year ago

> Are logits also known as unnormalized probs?

Maybe it is. After reading your pruned RNN-T paper and the code, it seems that the 'un-normalized logprobs' in your paper are the 'logits' I mentioned above, which are usually the output of a fully connected (FC) layer and the input to softmax or sigmoid. I just wanted to check this with you.
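For reference, the relationship between logits and normalized log-probs can be sketched in plain Python as below. This is only an illustration of the standard log-softmax definition, not the k2 implementation; the `log_softmax` helper and the example values are assumptions made for the sketch.

```python
import math


def log_softmax(logits):
    """Turn un-normalized scores (logits) into normalized log-probs."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


# Hypothetical logits in [0, 1), like values torch.rand() might produce.
logits = [0.84, 0.76, 0.42, 0.26, 0.51]

logprobs = log_softmax(logits)
total = sum(math.exp(lp) for lp in logprobs)

# Adding a constant to every logit leaves the normalized result unchanged,
# which is why logits are only defined up to an additive shift.
shifted = log_softmax([x + 3.0 for x in logits])

print("log-probs:", [round(lp, 4) for lp in logprobs])
print("sum of exp(log-probs):", round(total, 6))
```

Any real-valued vector is a valid set of logits under this definition, so values drawn from [0, 1) are an acceptable benchmark input even though they are neither normalized probs nor normalized log-probs.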