After looking at the nsys output, I think we are largely limited by the latency of sequential operations in IntersectDevice, IntersectDense, GetForwardScores and GetBackwardScores (and by the memory transfer when we invoke Array1::Back()).
I think there are two ways we can significantly reduce the time taken:
1. We can process the num and den FSAs together: concatenate the FsaVecs, call IntersectDevice() just once, get the tot_scores just once, and then post-process ranges of the tot_scores.
2. IntersectDevice() is also called when forming minibatches (intersecting with L and then with ctc_topo). If we can somehow batch these calls up it would be more efficient, though it might not be super convenient code-wise.
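
To make the first idea concrete, here is a rough sketch of the "concatenate, score once, split by range" pattern in plain Python. It does not use the real k2 API: `batched_tot_scores` and `CountingScorer` are hypothetical names, strings stand in for FSAs, and the scorer is a dummy that only counts how often it is invoked (as a proxy for the number of sequential kernel-launch chains).

```python
from typing import Callable, List, Sequence, Tuple


def batched_tot_scores(
    num: Sequence[str],
    den: Sequence[str],
    score_fn: Callable[[Sequence[str]], List[float]],
) -> Tuple[List[float], List[float]]:
    """Score num and den with a single call to score_fn, then split the result."""
    # Stands in for concatenating the num and den FsaVecs.
    combined = list(num) + list(den)
    # One call instead of two, so the per-call latency is paid once.
    tot_scores = score_fn(combined)
    # Post-process ranges of tot_scores: first len(num) entries are num.
    n = len(num)
    return tot_scores[:n], tot_scores[n:]


class CountingScorer:
    """Hypothetical stand-in for intersection + tot_scores; counts invocations."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, batch: Sequence[str]) -> List[float]:
        self.calls += 1
        # Dummy score: one value per item in the batch.
        return [float(len(item)) for item in batch]


scorer = CountingScorer()
num_scores, den_scores = batched_tot_scores(["ab", "abc"], ["a", "abcd"], scorer)
```

With separate num and den passes the scorer would be invoked twice; here it is invoked once, and the split is just index arithmetic on the result. Whether the real num and den graphs can be concatenated this cheaply is something that would need checking against the actual code.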