1ytic / warp-rnnt

CUDA-Warp RNN-Transducer
MIT License

warp-rnnt with compact memory layout implementation #26

Closed. maxwellzh closed this 1 year ago.

maxwellzh commented 2 years ago

The compact memory layout can be explained with this figure. The input to rnnt_loss() is of size (N, T, U+1, V) = (3, 6, 6, V) in the normal layout. Colored boxes denote data and white boxes denote padding. (Note that the V dimension is omitted in the figure.)

[figure: compactlayout]
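
For reference, one natural indexing scheme for such a compact layout (not necessarily the exact one used in the kernels here): only the colored boxes are stored, concatenated sequence by sequence, so lattice node (n, t, u) maps to row \sum_{i<n}{Ti(Ui+1)} + t*(Un+1) + u of a (\sum{Ti(Ui+1)}, V) tensor.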

I also implemented gather mode with the compact layout, which is the recommended option. Since it is difficult to select the indices in a compact layout (the current gather mode uses torch.gather), I integrated the gather operation and its backward computation in C++/CUDA.
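
For reference, a minimal sketch of the idea behind gather mode in the normal layout, assuming blank index 0 and torch.gather as mentioned above (sizes and variable names here are illustrative, not the exact binding):

import torch

N, T, U, V = 2, 6, 5, 10                         # hypothetical sizes, U = max label length
log_probs = torch.randn(N, T, U + 1, V).log_softmax(dim=-1)
labels = torch.randint(1, V, (N, U))             # hypothetical targets, 0 is blank

# each lattice node only needs two entries: log p(blank) and log p(next target label)
index = torch.zeros(N, T, U + 1, 2, dtype=torch.long)
index[:, :, :U, 1] = labels.unsqueeze(1).expand(N, T, U)
gathered = log_probs.gather(dim=3, index=index)  # (N, T, U+1, 2)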

Only the PyTorch binding is implemented.

maxwellzh commented 2 years ago

Benchmark results (tested on an RTX 3090):

1ytic commented 2 years ago

Wow, looks cool! Thank you! I will check it in the next few days. As far as I remember, there was a research paper about the compact layout. Maybe we should mention this in the README as well.

maxwellzh commented 2 years ago

Maybe this one? https://arxiv.org/abs/1909.12415

1ytic commented 2 years ago

Exactly, Section 3.1, "Efficient encoder and prediction output combination". Is it similar, or have I misunderstood?

maxwellzh commented 2 years ago

Yes, that's it. We can add it in the README later.

1ytic commented 2 years ago

I did a review, and I can't understand how it would be used in practice. @maxwellzh, could you shed some light on how to prepare the log_probs and labels arrays in practice? Is it possible to prepare a compact log_probs array efficiently? It looks like this array will always have the shape (N, T, U, V) before this cost function, and we have to convert it manually with the compactTensor function. Is that correct?

maxwellzh commented 2 years ago

The compactTensor function is just for testing. If we converted the tensor with that function every time rnnt_loss() is invoked, the overall performance might be poor. Normally, the joint network does something like this:

# (N, T, H_enc) -> (N, T, H_joint)
trans_enc = fc_enc(encoder_output) # linear layer over encoder, or the transcription network

# (N, U+1, H_dec) -> (N, U+1, H_joint)
trans_dec = fc_dec(decoder_output) # linear layer over decoder, or the prediction network

# (N, T, 1, H_joint) + (N, 1, U+1, H_joint) -> (N, T, U+1, H_joint)
expanded_input = trans_enc.unsqueeze(2) + trans_dec.unsqueeze(1) # broadcast to add up

# (N, T, U+1, H_joint) -> (N, T, U+1, V)
rnnt_input = classifier(torch.sigmoid(expanded_input)).log_softmax(dim=-1)

loss = rnnt_loss(rnnt_input, ...)

With the compact layout, in practice, I would recommend compacting the tensors before the first linear layer, like this:

# (N, T, H_enc) -> (ST, H_enc), ST denotes \sum{Ti}
encoder_output_compact = compactTensor(encoder_output) # make it compact
# (ST, H_enc) -> (ST, H_joint)
trans_enc = fc_enc(encoder_output_compact) # linear layer over encoder, or the transcription network

# (N, U+1, H_dec) -> (SU, H_dec), SU denotes \sum{Ui+1}
decoder_output_compact = compactTensor(decoder_output) # make it compact
# (SU, H_dec) -> (SU, H_joint)
trans_dec = fc_dec(decoder_output_compact) # linear layer over decoder, or the prediction network

# broadcasting the addition is not as straightforward as in the normal layout, so we have to add per sequence or implement a faster CUDA-bound function
# (ST, H_joint) + (SU, H_joint) -> (STU, H_joint), STU denotes \sum{Ti(Ui+1)}
expanded_input = TrickyAdd(trans_enc, trans_dec)

# (STU, H_joint) -> (STU, V)
rnnt_input = classifier(torch.sigmoid(expanded_input)).log_softmax(dim=-1)

loss = rnnt_loss(rnnt_input, ..., compact=True)

compactTensor() for a 3-dim tensor is easy to do with torch.cat. As for TrickyAdd(), I couldn't think of an efficient way to do it in Python, so I implemented it in CUDA along with its backward. In fact, I also implemented compactTensor() for 3-dim tensors in CUDA. Do you have any suggestions about this?
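
For reference, a pure-Python sketch of what these two helpers could compute (compact_tensor/tricky_add are illustrative stand-ins for the names discussed above; the CUDA versions are fused and provide their own backward):

import torch

def compact_tensor(x, lengths):
    # (N, L, H) padded -> (sum(lengths), H): concatenate only the valid rows
    return torch.cat([x[i, :l] for i, l in enumerate(lengths)], dim=0)

def tricky_add(trans_enc, trans_dec, frame_lengths, label_lengths):
    # (ST, H) + (SU, H) -> (STU, H): per-sequence broadcast add in the compact layout
    out, t0, u0 = [], 0, 0
    for T_i, U_i in zip(frame_lengths, label_lengths):
        enc_i = trans_enc[t0:t0 + T_i]            # (T_i, H)
        dec_i = trans_dec[u0:u0 + U_i]            # (U_i, H), U_i already includes the +1
        out.append((enc_i.unsqueeze(1) + dec_i.unsqueeze(0)).reshape(-1, enc_i.size(1)))
        t0, u0 = t0 + T_i, u0 + U_i
    return torch.cat(out, dim=0)                  # (\sum{Ti(Ui+1)}, H)

The Python loop is of course slow on GPU for large N, which is exactly why the fused CUDA version matters; this sketch is only to pin down the intended semantics.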

1ytic commented 2 years ago

Thanks for the clarification! Now it makes sense, you have the additional functions. Could you add these into the MR? Without these functions, it's not clear how to use the compact version more efficiently than the original one.

maxwellzh commented 2 years ago

It's something outside the RNN-T loss, so I'd like to create a new repo containing the implementation of these functions. It will take some time for me to prepare the code; I'll let you know when it's ready. BTW, the compact version gains its efficiency from squeezing out the padding, so the performance improvement also depends on N = batch_size. Theoretically, when N is large, there is a lot of padding in the normal layout, so the compact version is expected to be better. However, when N is relatively small, considering the overhead of compactTensor, the compact version might perform worse.
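
One way to quantify this: the normal layout allocates N * Tmax * (Umax+1) * V joint activations, while the compact layout allocates \sum{Ti(Ui+1)} * V, so the saving is roughly 1 - \sum{Ti(Ui+1)} / (N * Tmax * (Umax+1)), which grows as the lengths in the batch become more uneven.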

maxwellzh commented 2 years ago

I have released the implementation of these functions. https://github.com/maxwellzh/torch-gather

1ytic commented 2 years ago

Thank you! I will check it on the weekend.