1ytic / warp-rnnt

CUDA-Warp RNN-Transducer
MIT License

[Add] Support gather log_probs inside the loss function. #16

Closed TeaPoly closed 3 years ago

TeaPoly commented 3 years ago

I have implemented the gather log_probs operation inside the warprnnt loss function. The implementation is quite complicated for now, because TensorFlow does not have a function like torch.gather. I will continue to optimize this part in the future.

TeaPoly commented 3 years ago

Is it possible to implement the gather inside the CUDA C++ code, so that it takes log_probs [B, T, U, V], blank, and labels [B, U-1] and produces gather_log_probs [B, T, U, 2]?
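For reference, the shape transformation asked about here can be sketched in NumPy. This is only an illustration under stated assumptions, not the code from this pull request: the function name `gather_log_probs`, the use of `np.take_along_axis` as a stand-in for torch.gather, the ordering of the last axis (slot 0 for blank, slot 1 for the label), and the choice to fill the final `u` position (which has no next label) with the blank index are all assumptions.

```python
import numpy as np

def gather_log_probs(log_probs, labels, blank=0):
    """Reduce log_probs [B, T, U, V] to [B, T, U, 2].

    Assumed layout of the last axis: slot 0 is the blank log-probability,
    slot 1 is the log-probability of the next label at position u.
    """
    B, T, U, V = log_probs.shape
    assert labels.shape == (B, U - 1)

    # Start with the blank index everywhere; the last u position has no
    # next label, so its slot 1 simply stays at blank (an assumption).
    index = np.full((B, T, U, 2), blank, dtype=np.int64)

    # Broadcast labels [B, U-1] over the time axis into slot 1.
    index[:, :, :U - 1, 1] = labels[:, None, :]

    # Pick the two selected entries along the vocabulary axis.
    return np.take_along_axis(log_probs, index, axis=3)

# Small example with random values.
rng = np.random.default_rng(0)
B, T, U, V = 2, 3, 4, 5
log_probs = rng.standard_normal((B, T, U, V))
labels = rng.integers(1, V, size=(B, U - 1))
gathered = gather_log_probs(log_probs, labels, blank=0)
print(gathered.shape)
```

The gathered tensor keeps only two of the V entries per (b, t, u) cell, which is the memory saving that motivates doing the gather before or inside the loss.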

1ytic commented 3 years ago

@TeaPoly thank you! Sorry for the long delay. gather_log_probs looks amazing. I think you have to use TF primitives, since they automatically provide the derivatives needed for training.