maxwellzh opened this issue 1 year ago
Hello Huahuan Zheng, interesting theory! But I don't think it will be useful in practice. Optimising a forward pass doesn't make sense. You can check the CUDA profiler logs. The big issue is memory IO, and I really like your previous MR with the compact memory version. I wish to finish reviewing it and reopen your MR in the near future.
Will do further investigation later :)
As for the IO issue, I remember reading somewhere that a thread block will load nearby memory regardless of whether it is actually used (i.e. coalesced loads). Have you ever tried using the (N, U, T, V) layout instead of (N, T, U, V)? With the former (and especially when gather=True), a warp (also a thread block) is able to load a chunk of consecutive memory and reuse it.
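Roughly, the difference in index arithmetic would look like the sketch below (the helper names are mine, not from the repo, and the `V == 2` note assumes `gather=True` has already reduced the last axis to the blank and label scores). With lane `d` of a warp handling time step `t0 + d` at a fixed `(n, u, v)`, neighbouring lanes are `U*V` floats apart in the first layout but only `V` floats apart in the second, so the second can be served by a single coalesced load.

```cuda
// Sketch only: index arithmetic for the two layouts, assuming lane d of a warp
// handles time step t = t0 + d with n, u, v fixed. Helper names are hypothetical.

// (N, T, U, V): neighbouring lanes are U*V floats apart -> strided accesses,
// poorly coalesced.
__device__ inline int idx_ntuv(int n, int t, int u, int v, int T, int U, int V) {
    return ((n * T + t) * U + u) * V + v;   // stride over t is U*V
}

// (N, U, T, V): neighbouring lanes are only V floats apart (V == 2 if
// gather=True keeps just the blank and label scores), so a warp can load one
// consecutive chunk of memory and reuse it.
__device__ inline int idx_nutv(int n, int u, int t, int v, int T, int U, int V) {
    return ((n * U + u) * T + t) * V + v;   // stride over t is V
}
```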
Indeed, I've been using the compact-version loss function in our speech recognition tasks for a while. It should be technically correct (it's in my dev branch now; the main branch hasn't been updated for some time). I'll finish merging from my dev branch into the main branch, and once that's done, I'll reopen the MR.
I'm not familiar with the memory manager for CUDA threads. But you're right, having the T×U matrix is the main bottleneck. Fortunately, there is a solution for this: fast_rnnt. It looks really promising.
I've been following the fast_rnnt work for a while, but haven't managed a successful pruned RNN-T training yet.
They also have a paper about the implementation. https://arxiv.org/pdf/2206.13236.pdf
In the current implementation, the warps along the T axis are computed in a fully serialized manner: https://github.com/1ytic/warp-rnnt/blob/edd5857cd9abf29f12ab3fbc153f78f21191d80b/core.cu#L112-L134
The for loop of each warp is executed one by one, which means the i-th warp at a specific row u has to wait for all of its leading warps to finish their loops, which gives i (number of warps) * W (for-loop overhead; warpsize, 32 here) time complexity. However, we don't necessarily have to wait for previous warps to finish before we go into the loop in the current warp.
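To make the serialization concrete, here is a minimal, hypothetical sketch of that pattern (the real kernel linked above differs in details such as log-space arithmetic and indexing; `serial_alpha`, `emit_prev` and `B` are names I made up for illustration): only one lane is "unlocked" per iteration, so a warp needs about W serial steps, and warp i along T cannot start its useful work until warp i-1 has produced its last alpha.

```cuda
#define W 32  // warpsize

// Sketch of the serialized in-warp recursion (not the actual core.cu kernel).
// Lane d owns one time step; `emit_prev` is the emitting score carried over
// from the previous step and `B` is the already-available contribution from
// row u-1. Lane 0's `alpha` is the result handed over by the previous warp.
__device__ float serial_alpha(int d, float alpha, float emit_prev, float B) {
    for (int i = 1; i < W; ++i) {
        // every lane reads the current alpha of the lane just below it
        float prev = __shfl_up_sync(0xffffffffu, alpha, 1);
        if (d == i) {
            // only lane i is unlocked in iteration i -> W-1 serial iterations
            alpha = prev * emit_prev + B;  // done with log-sum-exp in practice
        }
    }
    return alpha;
}
```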
Let's take the forward computation of the alphas as the example with warpsize=4. Here d denotes the index inside a warp, so 0 <= d < W. B is the result from row u-1 and is supposed to be ready. The forward computation of alpha follows (indeed we do the computation in logarithm; it is written with plain probabilities here just for discussion):

alpha_d = alpha_{d-1} * e_{d-1} + B_d

where e denotes the emitting probability.
Note that alpha_0 relies on the result from the last warp. Here comes the trick: I rewrote the alpha_3 formula into the following

alpha_3 = alpha_2 * e_2 + B_3
        = alpha_0 * (e_0 * e_1 * e_2) + (B_1 * e_1 * e_2 + B_2 * e_2 + B_3)

The second term in parentheses is warp-independent. In the first term, the product of emitting probabilities e_2 e_1 e_0 can be computed by a prefix sum (scan) algorithm in logarithm, and only introduces log2(W) complexity. Finally, the new procedure is like:

1. For every lane d, compute the prefix product of emitting probabilities e_0 ... e_{d-1} with a warp-level scan (a prefix sum in log space).
2. Compute the warp-independent part (the accumulated B terms) within the warp.
3. Wait for the previous warp's result, then combine it with the prefix products and the warp-independent part to obtain the final alphas.
For all warps at row u, steps 1 & 2 can be done in parallel; the i-th warp only has to wait for all previous warps to finish step 3. The new procedure should be considerably faster than the current serialized execution, especially when T is large.
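As a rough illustration of steps 1–3 (a sketch under my own naming, not code from this repo): the recursion alpha_d = alpha_{d-1} * e_{d-1} + B_d is an affine map applied to alpha_{d-1}, and composing affine maps is associative, so a warp can scan them with `__shfl_up_sync` in log2(W) rounds (steps 1 & 2 together); only the final combination with the previous warp's alpha (step 3) stays serialized across warps.

```cuda
#include <math.h>

#define W 32
#define FULL_MASK 0xffffffffu

// Numerically stable log(exp(a) + exp(b)).
__device__ inline float log_sum_exp(float a, float b) {
    float m = fmaxf(a, b);
    if (m == -INFINITY) return -INFINITY;
    return m + logf(expf(a - m) + expf(b - m));
}

// Hypothetical sketch of the proposed per-warp procedure (all values in log space).
//   log_e_prev          : log e_{d-1}, the emitting score carried into this lane's step
//   log_B               : log B_d, the ready contribution from row u-1
//   log_alpha_prev_warp : last alpha of the previous warp (only lane 0 needs to hold it)
// Lane d's update is the affine map x -> e_{d-1} * x + B_d. After an inclusive
// scan over these maps, lane d holds (A, C) such that
// alpha_d = exp(A) * alpha_prev + exp(C), with alpha_prev the previous warp's last alpha.
__device__ float warp_alpha(int d, float log_e_prev, float log_B,
                            float log_alpha_prev_warp) {
    float A = log_e_prev;  // log of the accumulated product of emitting scores
    float C = log_B;       // log of the accumulated, warp-independent B terms

    // Steps 1 & 2: Hillis-Steele inclusive scan over the affine maps,
    // log2(W) shuffle rounds, no dependency on any other warp.
    for (int offset = 1; offset < W; offset <<= 1) {
        float A_up = __shfl_up_sync(FULL_MASK, A, offset);
        float C_up = __shfl_up_sync(FULL_MASK, C, offset);
        if (d >= offset) {
            // compose: apply the map from lane d-offset first, then this lane's map
            C = log_sum_exp(C_up + A, C);
            A = A_up + A;
        }
    }

    // Step 3: the only serial dependency -- broadcast the previous warp's last
    // alpha from lane 0 and apply the composed map.
    float a_prev = __shfl_sync(FULL_MASK, log_alpha_prev_warp, 0);
    return log_sum_exp(a_prev + A, C);
}
```

In this sketch steps 1 and 2 are fused into a single scan over the (A, C) pairs; the only thing still handed from warp to warp is `log_alpha_prev_warp`, which is the "wait for previous warps to finish step 3" part.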