awni / transducer

A Fast Sequence Transducer Implementation with PyTorch Bindings
Apache License 2.0

transducer

A fast RNN-Transducer implementation for the CPU and GPU (CUDA) with Python bindings and a PyTorch extension. The RNN-T loss function was introduced in the paper Sequence Transduction with Recurrent Neural Networks.

The code has been tested with Python 3.9 and PyTorch 1.9.

Install and Test

To install, run the following from the top level of the repo:

python setup.py install

To use the PyTorch extension, install PyTorch and test with:

python torch_test.py

Usage

The easiest way to use the transducer loss is with the PyTorch bindings:

criterion = transducer.TransducerLoss()
loss = criterion(emissions, predictions, labels, input_lengths, label_lengths)

The loss will run on the same device as the input tensors. For more information, see the criterion documentation.
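To make concrete what the loss computes, here is a minimal pure-Python reference of the RNN-T forward recursion for a single utterance. It is a sketch for checking understanding on tiny inputs, not the code used by this library. It assumes that emissions has shape T x V, that predictions has shape (U + 1) x V, that the blank token has index 0, and that the two outputs are summed and then normalized with a log-softmax over the token set. The README states only the additive combination (see Memory Use and Benchmarks). The shapes, the blank index, and the function name rnnt_loss_reference are assumptions made for illustration.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over a list of floats."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def logaddexp(a, b):
    """log(exp(a) + exp(b)), treating -inf as log(0)."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_loss_reference(emissions, predictions, labels, blank=0):
    """Negative log-likelihood of `labels` for one utterance.

    emissions:   T lists of V scores (assumed layout)
    predictions: U + 1 lists of V scores (assumed layout)
    labels:      U token indices
    blank:       index of the blank token (assumed to be 0)
    """
    T, U = len(emissions), len(labels)
    assert len(predictions) == U + 1

    def log_probs(t, u):
        # Additive combination of the two outputs, then normalization over V.
        return log_softmax([e + p for e, p in zip(emissions[t], predictions[u])])

    # alpha[t][u]: log-probability of having emitted labels[:u] on reaching frame t.
    alpha = [[-math.inf] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t > 0:  # arrive by emitting blank at (t - 1, u)
                step = alpha[t - 1][u] + log_probs(t - 1, u)[blank]
                alpha[t][u] = logaddexp(alpha[t][u], step)
            if u > 0:  # arrive by emitting labels[u - 1] at (t, u - 1)
                step = alpha[t][u - 1] + log_probs(t, u - 1)[labels[u - 1]]
                alpha[t][u] = logaddexp(alpha[t][u], step)
    # Every alignment ends with a final blank from the last lattice node.
    return -(alpha[T - 1][U] + log_probs(T - 1, U)[blank])


# Uniform scores with T=2, U=1, V=2: two alignments of probability 1/8 each.
uniform = [[0.0, 0.0], [0.0, 0.0]]
print(rnnt_loss_reference(uniform, uniform, [1]))  # log(4), about 1.386
```

This reference recomputes the log-softmax at each lattice node rather than storing a T x U x V tensor, which keeps its working memory proportional to T * U. It is far too slow for real inputs.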

To get the "teacher forced" best path:

predicted_labels = criterion.viterbi(emissions, predictions, input_lengths, label_lengths)

Memory Use and Benchmarks

This implementation is designed to use much less memory than most others. Most implementations use memory that scales with the product B * T * U * V, where B is the batch size, T is the maximum input length in the batch, U is the maximum output length in the batch, and V is the token set size. The memory of this implementation scales with the product B * T * U and does not grow with the token set size. This is particularly important for the large token sets commonly used with word pieces. (NB: In this implementation you cannot use a "joiner" network to combine the outputs of the transcription and prediction models. The algorithm hardcodes the fact that these outputs are combined additively.)
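The difference between the two scaling laws can be estimated with simple arithmetic. The sketch below counts only the dominant term in each case and assumes 4-byte (float32) elements. It ignores gradients, the input tensors, and allocator overhead, so the figures are rough lower bounds rather than measurements. The helper name lattice_bytes is hypothetical.

```python
def lattice_bytes(B, T, U, V=1, bytes_per_elem=4):
    """Bytes for one dense tensor of B * T * U * V elements (float32 assumed)."""
    return B * T * U * V * bytes_per_elem


B, T, U = 8, 2000, 100  # sizes of the first benchmark configuration
compact = lattice_bytes(B, T, U)
for V in (100, 1000, 2000, 10000):
    full = lattice_bytes(B, T, U, V)
    print(f"V={V:>5}: B*T*U*V = {full / 1e9:6.2f} GB, B*T*U = {compact / 1e6:.1f} MB")
```

Under these assumptions, a single B * T * U * V tensor at V = 1000 already takes 6.4 GB, while a B * T * U tensor takes 6.4 MB regardless of V.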

Performance benchmarks for the CUDA version running on an A100 GPU are below. We compare against the Torch Audio RNN-T loss, which was run on the same A100 GPU. An entry of "OOM" means the implementation ran out of memory (20GB in this case).

Times are reported in milliseconds.

T=2000, U=100, B=8

V        Transducer (ms)    Torch Audio (ms)
100      8.18               139.26
1000     13.64              OOM
2000     18.83              OOM
10000    59.18              OOM

T=2000, U=100, B=32

V        Transducer (ms)    Torch Audio (ms)
100      20.58              555.00
1000     38.42              OOM
2000     58.19              OOM
10000    223.33             OOM