1ytic / warp-rnnt

CUDA-Warp RNN-Transducer

improve efficiency of warps #38

Open maxwellzh opened 1 year ago

maxwellzh commented 1 year ago

In the current implementation, the warps along the T axis are computed in a fully serialized manner: https://github.com/1ytic/warp-rnnt/blob/edd5857cd9abf29f12ab3fbc153f78f21191d80b/core.cu#L112-L134

The for loop of each warp is executed one by one, which means the i-th warp at a specific row u has to wait for all its preceding warps to finish their loops. That gives a time cost of i (the number of preceding warps) * W (the per-warp loop overhead, i.e. the warp size, 32 here).
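To make the serialization concrete, the per-warp pattern is roughly the following. This is a simplified sketch of the idea only, not the actual core.cu code; `log_sum_exp`, `alpha_warp_serialized`, and the parameter names are illustrative.

```cuda
// Simplified sketch (illustrative, not the actual core.cu kernel) of the
// serialized per-warp recursion along T at a fixed row u. Each of the W lanes
// owns one time step; lane d cannot finalize its alpha until lane d-1 has,
// so the loop effectively finalizes one lane per iteration.

__device__ __forceinline__ float log_sum_exp(float a, float b) {
    float m = fmaxf(a, b);
    return m + logf(expf(a - m) + expf(b - m));
}

// alpha : this lane's alpha, pre-initialized with the (ready) row u-1 term B_d
// log_e : log emission prob on the edge entering this lane (i.e. e_{d-1})
// d     : lane index inside the warp, 0 <= d < W (W = warpSize = 32)
__device__ float alpha_warp_serialized(float alpha, float log_e, int d) {
    for (int i = 1; i < 32; ++i) {
        float left = __shfl_up_sync(0xffffffffu, alpha, 1);  // alpha of lane d-1
        if (d == i)
            alpha = log_sum_exp(left + log_e, alpha);        // exactly one lane finalizes per iteration
    }
    return alpha;
}
```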

However, we don't necessarily have to wait for the previous warps to finish before entering the loop in the current warp.

Let's take the forward computation of alphas as an example with warpsize = 4 (figure: warp_sample). Here d denotes the index inside a warp, so 0 <= d < W. B is the result from row u-1 and is supposed to be ready.

The forward computation of alpha follows (in practice the computation is done in log space; this is just for discussion):

$$\alpha_d = \alpha_{d-1}\, e_{d-1} + B_d, \qquad 0 \le d < W$$

Note that alpha_0 relies on the result from the previous warp (its alpha_{-1} and e_{-1} come from the previous warp's last element).

Here comes the trick: unrolling the recursion down to alpha_0, the alpha_3 formula can be rewritten as

$$\alpha_3 = e_2 e_1 e_0 \,\alpha_0 + \underline{e_2 e_1 B_1 + e_2 B_2 + B_3}$$

The underlined part is warp-independent (it only involves B and e values owned by the current warp). The first part (the product of emitting probabilities e_2 e_1 e_0) can be computed with a prefix-sum (scan) algorithm in log space, which only introduces O(log2(W)) complexity.
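For reference, such a scan is just a standard warp-level inclusive prefix sum applied to the log emission probabilities (a running product in probability space). A generic sketch, not code from this repository:

```cuda
// Sketch: inclusive prefix sum of log e over the 32 lanes of a warp, i.e.
// lane d ends up holding log(e_0 * e_1 * ... * e_d), in log2(32) = 5 shuffle steps.
// (The factor e_0...e_{d-1} used in the rewritten formula is the same value
//  read from lane d-1, or an exclusive variant of this scan.)
__device__ float warp_prefix_sum_log(float log_e, int d /* lane index */) {
    for (int offset = 1; offset < 32; offset <<= 1) {
        float up = __shfl_up_sync(0xffffffffu, log_e, offset);
        if (d >= offset)
            log_e += up;
    }
    return log_e;
}
```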

Finally, the new procedure is as follows:

  1. Compute the local path-combination probabilities (the underlined part): O(W) complexity;
  2. Compute the products of emitting probabilities (e2 e1 e0, ...) with a prefix-sum algorithm: O(log2(W)) complexity;
  3. Wait for the previous warps to finish and compute the final results: constant complexity.

For all warps at row u, steps 1 and 2 can be done in parallel; the i-th warp only has to wait for its preceding warps to finish step 3. The new procedure should be considerably faster than the current serialized execution, especially when T is large.
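Putting the three steps together, the per-warp procedure might look roughly like the sketch below. It assumes one possible convention (each lane holds the log probability of the edge entering it, and the decomposition is written against the previous warp's last alpha rather than alpha_0, which is equivalent); all names are illustrative and the actual waiting/synchronization on the previous warp is elided.

```cuda
__device__ __forceinline__ float log_sum_exp(float a, float b) {
    float m = fmaxf(a, b);
    return m + logf(expf(a - m) + expf(b - m));
}

// Sketch of the proposed per-warp procedure (untested, names illustrative).
//   log_B           : log B_d, the ready contribution from row u-1
//   log_e_in        : log e_{d-1}, emission prob on the edge entering lane d
//                     (for lane 0 this is the edge from the previous warp)
//   d               : lane index, 0 <= d < 32
//   alpha_prev_last : the previous warp's final alpha, assumed available
//                     by the time step 3 runs (waiting logic elided)
__device__ float alpha_warp_proposed(float log_B, float log_e_in, int d,
                                     float alpha_prev_last) {
    const unsigned mask = 0xffffffffu;

    // Step 1: local path combination (the underlined part), O(W).
    // local_d = log( B_d + e_{d-1} B_{d-1} + e_{d-1} e_{d-2} B_{d-2} + ... ),
    // which only involves values owned by this warp.
    float local = log_B;
    for (int i = 1; i < 32; ++i) {
        float left = __shfl_up_sync(mask, local, 1);
        if (d == i)
            local = log_sum_exp(left + log_e_in, local);
    }

    // Step 2: prefix sum of the incoming log e over the warp, O(log2(W)).
    // prefix_d = log( e_{-1} e_0 ... e_{d-1} ), the factor that multiplies the
    // previous warp's last alpha.
    float prefix = log_e_in;
    for (int offset = 1; offset < 32; offset <<= 1) {
        float up = __shfl_up_sync(mask, prefix, offset);
        if (d >= offset)
            prefix += up;
    }

    // Step 3: once the previous warp is done, a single log-add per lane, O(1).
    return log_sum_exp(alpha_prev_last + prefix, local);
}
```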

1ytic commented 1 year ago

Hello Huahuan Zheng, interesting theory! But I don't think it will be useful in practice. Optimising the forward pass doesn't make much sense; you can check the CUDA profiler logs. The big issue is memory IO, and I really like your previous MR with the compact memory version. I hope to finish reviewing it and reopen your MR in the near future.

maxwellzh commented 1 year ago

Will do further investigation later :)

As for the IO issue, I remember reading somewhere that a thread block loads nearby memory in whole transactions whether it is all used or not. Have you ever tried using an (N, U, T, V) layout instead of (N, T, U, V)? With the former (and especially when gather=True), a warp (also a thread block) is able to load a chunk of consecutive memory and reuse it.
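For concreteness, the indexing difference I have in mind is sketched below, assuming the gathered case where the last dimension holds the {blank, label} pair; the helper names are made up.

```cuda
// Hypothetical index helpers for the gathered log-probs (last dim = {blank, label}).

// Current (N, T, U, 2) layout: a warp spanning consecutive t at a fixed u
// touches addresses 2*U elements apart, so its loads are scattered.
__device__ __forceinline__ int idx_ntu(int n, int t, int u, int k, int T, int U) {
    return ((n * T + t) * U + u) * 2 + k;   // stride over t is 2*U
}

// Alternative (N, U, T, 2) layout: the same warp reads a contiguous chunk,
// so the loads coalesce and the chunk can be reused within the warp/block.
__device__ __forceinline__ int idx_nut(int n, int t, int u, int k, int T, int U) {
    return ((n * U + u) * T + t) * 2 + k;   // stride over t is 2
}
```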

Indeed, I've been using the compact-memory version of the loss function in our speech recognition tasks for a while. It should be technically correct (it's in my dev branch now; the main branch hasn't been updated for some time). I'll finish merging from my dev branch into the main branch, and once that's done, I will reopen the MR.

1ytic commented 1 year ago

I’m not familiar with the memory management for CUDA threads. But you're right, having the T×U matrix is the main bottleneck. Fortunately, there is a solution for this, fast_rnnt. It looks really promising.

maxwellzh commented 1 year ago

I've been following the fast_rnnt work for a while, but haven't managed a successful pruned RNN-T training yet.

They also have a paper about the implementation: https://arxiv.org/pdf/2206.13236.pdf