pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

Proposal for the integration of Tree-constrained Pointer Generator and Minimum Biasing Word Error (MBWE) training for contextual ASR #2483

Open BriansIDP opened 2 years ago

BriansIDP commented 2 years ago

🚀 The feature

I’d like to propose the integration of the tree-constrained pointer generator (TCPGen) [1] and Minimum Biasing Word Error (MBWE) training [2] for contextual biasing into the torchaudio package.

The implementation of TCPGen consists of two parts. The first part organises the biasing list and constructs a wordpiece-level prefix tree. The second part performs the prefix-tree search and calculates the relevant pointer generator distributions in RNNT. During training, at the beginning of each minibatch, a biasing list is organised by finding the biasing words in the current utterance and adding a certain number of distractors. This biasing list is then converted into a prefix tree. Then, at each RNNT predictor step, a prefix-tree search is performed according to the wordpiece history, and a set of valid wordpieces is selected. For each combination of encoder and predictor steps in RNNT, in addition to the model output distribution, we calculate another distribution over these valid wordpieces (the TCPGen distribution). An interpolation between the model output distribution and this TCPGen distribution is then performed to achieve contextual biasing. The interpolation weight is also predicted by TCPGen, indicating how much contextual biasing is needed at each step. Inference follows the same procedure as training.
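
To make the prefix-tree part concrete, here is a minimal sketch of how the wordpiece-level prefix tree and the per-step search could look; the class and function names are purely illustrative and not part of any existing torchaudio API.

```python
from typing import Dict, List


class TrieNode:
    """One node of the wordpiece-level prefix tree (illustrative sketch)."""

    def __init__(self) -> None:
        self.children: Dict[int, "TrieNode"] = {}  # wordpiece id -> child node
        self.is_word_end: bool = False             # a biasing word ends here


def build_prefix_tree(biasing_list: List[List[int]]) -> TrieNode:
    """Build a prefix tree from biasing words, each given as a list of wordpiece ids."""
    root = TrieNode()
    for wordpieces in biasing_list:
        node = root
        for wp in wordpieces:
            node = node.children.setdefault(wp, TrieNode())
        node.is_word_end = True
    return root


def valid_wordpieces(node: TrieNode) -> List[int]:
    """Wordpieces reachable from the current node; these receive TCPGen probability mass."""
    return list(node.children.keys())


def step(node: TrieNode, root: TrieNode, wp: int) -> TrieNode:
    """Advance the tree search by one emitted wordpiece.

    With suffix-marked wordpieces the end of a word is unambiguous, so the search
    returns to the root whenever a biasing word is completed or the wordpiece
    history falls off the tree.
    """
    child = node.children.get(wp)
    if child is None:
        return root
    return root if child.is_word_end and not child.children else child
```

In the training procedure described above, the tree would be rebuilt at the start of each minibatch from the biasing words found in the utterances plus the sampled distractors.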

For implementation, I would put the first part (organising prefix trees) under some supporting function directory, such as utils/. I would put the second part (tree search and calculation of the relevant distributions) in the RNNT model implementation, before the joint network. It takes each combination of predictor and encoder output vectors and predicts a distribution for each combination, which is of the same size as the logits output. The output vector of TCPGen, which is a (weighted) sum of wordpiece embeddings, can also be fed into the joint network to achieve deep biasing.
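
As a rough illustration of where the second part would sit (before the joint network), a simplified interpolation step could look like the following; the tensor names, shapes, and the handling of the generation probability are assumptions for this sketch and omit details from the papers such as the out-of-list probability.

```python
import torch.nn.functional as F


def tcpgen_interpolate(model_probs, tcpgen_scores, valid_mask, gen_prob):
    """Interpolate the model output distribution with the TCPGen distribution.

    model_probs:   (B, T, U, V) softmax of the joint network output
    tcpgen_scores: (B, T, U, V) unnormalised TCPGen scores
    valid_mask:    (B, T, U, V) bool, True for wordpieces allowed by the prefix-tree search
    gen_prob:      (B, T, U, 1) interpolation weight predicted by TCPGen
    """
    # The TCPGen distribution only covers wordpieces that are valid in the prefix tree.
    # (Assumes at least one valid wordpiece per step; otherwise gen_prob should be forced to 0.)
    masked = tcpgen_scores.masked_fill(~valid_mask, float("-inf"))
    tcpgen_probs = F.softmax(masked, dim=-1)

    # Larger gen_prob means more contextual biasing at this (t, u) step.
    return (1.0 - gen_prob) * model_probs + gen_prob * tcpgen_probs
```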

For MBWE training, we need to perform a batched beam search to get n-best hypotheses for each utterance in the minibatch, and then calculate the expected word error (WE) and biasing word error (BWE) for each utterance. For the batched beam search, we adopted a one-step constrained beam search, which constrains the output at each encoder step to at most one wordpiece token so that all forward computation can be parallelised. Once the n-best hypotheses are obtained, the WE and the BWE are calculated. The BWE is the edit distance between the sequences of biasing words that appear in the hypothesis and in the reference. The WE and BWE for each hypothesis are summed and then multiplied by the normalised hypothesis probabilities to get the final MBWE loss to be optimised. The batched beam search could potentially go into the rnnt_decoder.py file, and the MBWE loss calculation could either be implemented directly as a class method of the RNNT model or put separately into some supporting function directory.
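
A minimal sketch of how the MBWE loss itself could be assembled from an n-best list, assuming the per-hypothesis errors have already been computed; the function name and argument layout are illustrative only.

```python
import torch


def mbwe_loss(nbest_log_probs: torch.Tensor, nbest_errors: torch.Tensor) -> torch.Tensor:
    """Expected (biasing) word error over the n-best list of one utterance.

    nbest_log_probs: (N,) total log probability of each hypothesis from the beam search
    nbest_errors:    (N,) WE + BWE of each hypothesis against the reference
    """
    # Normalise the hypothesis probabilities over the n-best list.
    weights = torch.softmax(nbest_log_probs, dim=0)
    # Expected error; minimising this pushes probability mass towards
    # hypotheses with fewer (biasing) word errors.
    return torch.sum(weights * nbest_errors)
```

As in other MBR-style objectives, the error counts themselves are not differentiable, so the gradient flows through the normalised hypothesis probabilities.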

There are two main issues with the implementation. First, the interpolation is between two distributions, but the RNNT loss function takes logits as arguments, so it is necessary to modify the RNNT loss to accommodate such an interpolation. The second issue is that in our experiments the wordpiece models are suffix-based models that mark the word boundary at the end of a word. The benefit is that it is clear when to return to the root of the tree, whereas for a prefix-based model any word-initial wordpiece could be valid, which affects the effectiveness of both our method and deep biasing.

Please refer to the following papers for more details.

[1] Guangzhi Sun, Chao Zhang, Philip C Woodland. Tree-constrained Pointer Generator for End-to-end Contextual Speech Recognition. In Proc. ASRU. 2021. Link: https://arxiv.org/abs/2109.00627.

[2] Guangzhi Sun, Chao Zhang, Philip C Woodland. Minimising Biasing Word Errors for Contextual ASR with the Tree-Constrained Pointer Generator. Link: https://arxiv.org/pdf/2205.09058.pdf.

Motivation, pitch

Contextual biasing, which integrates contextual knowledge into an ASR system, has become increasingly important for many applications. The tree-constrained pointer generator (TCPGen) is an effective contextual biasing component that combines the benefits of deep biasing and shallow-fusion-based contextual biasing. The biasing list, which contains words or entities that are likely to appear in a given context, can be customised depending on the use case. In addition, minimum biasing word error (MBWE) training provides an effective way of training contextual ASR systems by particularly emphasising the biasing word errors.

Alternatives

No response

Additional context

No response

mthrok commented 2 years ago

Can @xiaohui-zhang comment on this?

xiaohui-zhang commented 2 years ago

@BriansIDP as we discussed offline, our team welcomes this feature (as a prototype feature first). So feel free to proceed with PRs at your convenience.

jtrmal commented 1 year ago

Hi guys, just pinging back: has there been any progress? I can help with implementing/testing.

xiaohui-zhang commented 1 year ago

@BriansIDP can you share your progress here? Thanks!

BriansIDP commented 1 year ago

Hi @jtrmal and @xiaohui-zhang. I have implemented most parts of the code in torchaudio. However, as the current RNNT loss only accepts logits rather than log probabilities, and TCPGen requires probability interpolation, a version of the RNNT loss that accepts log probabilities as input (or just probabilities, bypassing the softmax inside the RNNT loss) is necessary. I believe this loss is under implementation, and I expect it to be finished in November or December. Once it is finished, I will finish my code and examples as soon as possible. I hope this helps.

xiaohui-zhang commented 1 year ago

the "log-prob as input" is being addressed by @carolineechen. @carolineechen if possible, when you start working on it, please give a more accurate ETA here. Thanks!

desh2608 commented 1 year ago

I am not super familiar with TCPGen, but my understanding was that directly using logits in the RNN-T loss (with "function merging") saves a lot of memory since you don't have to store 3 large tensors for backprop. I believe most implementations pass logits and modify the RNN-T formulation accordingly for this reason. Would going back to probs (or log-probs) mean giving up on these memory savings?

carolineechen commented 1 year ago

@BriansIDP @xiaohui-zhang just an update that log-probs support is enabled in #2798; it should be available in the nightly/source builds of torchaudio.

@desh2608 this is a good point -- the original implementation directly uses logits with function merging, as mentioned in this paper, which helps to reduce memory. This option is still the default and its implementation is unchanged; we are simply expanding the capability to support either input type (logits or log-probs). I am also not very familiar with TCPGen; @BriansIDP, I'm not sure whether function merging is something that could be compatible with TCPGen.
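
For reference, a hedged sketch of how the new option could be used with TCPGen-style interpolated probabilities on toy shapes; the `fused_log_softmax` argument name reflects my reading of #2798 and should be checked against the nightly docs.

```python
import torch
import torchaudio.functional as AF

B, T, U, V = 2, 10, 5, 32  # batch, time, target length + 1, vocab (toy sizes)

# Stand-in for TCPGen-interpolated probabilities of shape (B, T, U, V).
probs = torch.softmax(torch.randn(B, T, U, V), dim=-1)
targets = torch.randint(1, V, (B, U - 1), dtype=torch.int32)
logit_lengths = torch.full((B,), T, dtype=torch.int32)
target_lengths = torch.full((B,), U - 1, dtype=torch.int32)

loss = AF.rnnt_loss(
    torch.log(probs),          # pass log-probs instead of logits
    targets,
    logit_lengths,
    target_lengths,
    blank=0,
    fused_log_softmax=False,   # assumed flag name from #2798; verify against the docs
)
```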

BriansIDP commented 1 year ago

Thank you very much @carolineechen for addressing this quickly! Thank you @desh2608 for pointing this out. At the moment TCPGen works with probabilities rather than logits, but I will work on this in the future to see whether function merging can be taken into consideration.