Following a suggestion from the PaddlePaddle community, it is recommended to add `use_softmax` and `zero_infinity` options to warpctc, matching the PyTorch API `torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)`.
This PR attempts to add these two options; the logic is as follows:
`use_softmax`: In PyTorch, the `log_probs` argument receives the output of a log_softmax operation. So when users want to pass log_softmax output directly as input, we can skip the softmax operation in our code. The proposed change is that when `use_softmax=False`, we apply an exponential to the input logits (treating them as log-probabilities) instead of performing softmax.
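A minimal NumPy sketch of this idea (not the actual warpctc code; the function name `to_probs` is hypothetical): exponentiating a log_softmax output recovers the same probabilities that softmax would produce from the raw logits.

```python
import numpy as np

def to_probs(logits, use_softmax=True):
    # Hypothetical illustration of the proposed use_softmax behavior.
    if use_softmax:
        # Standard path: normalize raw logits with a numerically stable softmax.
        shifted = logits - logits.max(axis=-1, keepdims=True)
        e = np.exp(shifted)
        return e / e.sum(axis=-1, keepdims=True)
    # use_softmax=False: the input is already log-probabilities
    # (e.g. a log_softmax output), so exponentiating recovers probabilities.
    return np.exp(logits)

x = np.array([[1.0, 2.0, 3.0]])
log_probs = np.log(to_probs(x))            # simulate a log_softmax output
p1 = to_probs(x, use_softmax=True)         # softmax on raw logits
p2 = to_probs(log_probs, use_softmax=False)  # exp on log-probabilities
assert np.allclose(p1, p2)
```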
`zero_infinity`: When `zero_infinity=True`, infinite costs and their associated gradients should be zeroed. To allow the cost to actually reach infinity, I first removed the truncation applied to the logits; then, after computing the cost and gradient for each batch of the input, I zero out the infinite entries.
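A small sketch of the zeroing step, assuming per-sample costs and per-sample gradient slices (function name `apply_zero_infinity` is hypothetical, not the warpctc API):

```python
import numpy as np

def apply_zero_infinity(costs, grads, zero_infinity=True):
    # Hypothetical sketch: after the cost and gradient for each batch sample
    # are computed, zero any infinite cost and its associated gradient slice.
    if zero_infinity:
        inf_mask = np.isinf(costs)
        costs = np.where(inf_mask, 0.0, costs)
        grads[inf_mask] = 0.0  # zero the whole gradient for those samples
    return costs, grads

costs = np.array([1.5, np.inf, 2.0])   # sample 1 has an infinite cost
grads = np.ones((3, 4))                # dummy per-sample gradients
c, g = apply_zero_infinity(costs.copy(), grads.copy())
# c == [1.5, 0.0, 2.0]; g[1] is all zeros, g[0] and g[2] are untouched
```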
I also added the corresponding tests for these two cases. (PS: the `inf_test` seems to pass only when the truncation steps are omitted.)
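To illustrate why the test behaves this way (a toy sketch, not the actual test; the clamp value is made up): clamping log-probabilities away from negative infinity bounds the total cost, so an infeasible alignment can never produce the `inf` that `zero_infinity` is meant to handle.

```python
import numpy as np

# Toy path of three per-step log-probabilities; one step is impossible (-inf).
log_p = np.array([-np.inf, -2.0, -0.5])
truncated = np.maximum(log_p, -50.0)   # a clamp like the removed truncation
cost_raw = -log_p.sum()                # inf: the impossible step is visible
cost_trunc = -truncated.sum()          # finite: the inf is masked by the clamp
assert np.isinf(cost_raw) and np.isfinite(cost_trunc)
```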