k2-fsa / snowfall

Moved to https://github.com/k2-fsa/icefall
Apache License 2.0

Plan for multi pass n-best rescoring #232

Open danpovey opened 3 years ago

danpovey commented 3 years ago

[Guys, I have gym now so I'll submit this and write the rest of this later today. ]

I am creating an issue to describe a plan for multi-pass n-best-list rescoring. This will also require new code in k2; I'll create a separate issue for that. The scenario is that we have a CTC or LF-MMI model and we do the 1st decoding pass from that. Anything that we can do with lattices, we do first (e.g. including any FST-based LM rescoring). Let the possibly-LM-rescored lattice be the starting point for the n-best rescoring process.

The first step is to generate a long n-best list for each lattice by calling RandomPaths() with a largish number, like 1000. We then choose unique paths based on token sequences, where 'token' is whatever type of token we are using in the transformer and RNNLM-- probably word pieces. That is, we use inner_labels='tokens' when doing the composition with the CTC topo when making the decoding graph, and these get propagated to the lattices, so we can use lats.tokens and remove epsilons and pick the unique paths.
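To make the "choose unique paths" step concrete, here is a minimal pure-Python sketch. It does not use k2; the paths are assumed to be already extracted as lists of token ids, and `EPSILON = 0` and the helper name `unique_token_paths` are assumptions made only for illustration.

```python
from typing import Dict, List, Tuple

EPSILON = 0  # assumed id of the epsilon symbol, for illustration only


def unique_token_paths(paths: List[List[int]]) -> Tuple[List[List[int]], List[int]]:
    """Strip epsilons from each sampled path and keep one copy of each
    distinct token sequence.

    Returns the unique sequences (in order of first appearance) and, for
    each of them, the index of the sampled path it first came from.
    """
    seen: Dict[Tuple[int, ...], int] = {}
    for idx, path in enumerate(paths):
        key = tuple(t for t in path if t != EPSILON)
        if key not in seen:  # keep the first occurrence only
            seen[key] = idx
    return [list(k) for k in seen], list(seen.values())


# Three sampled paths; the first two differ only in where the epsilons are.
sampled = [[0, 5, 0, 7], [5, 7, 0], [5, 0, 8]]
uniq, first_idx = unique_token_paths(sampled)
```

In the real pipeline this would presumably be done on ragged tensors for all lattices at once; the sketch only shows the logic for a single lattice.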

I think we could have a data structure called Nbest-- we could draft this in snowfall for now and later move to k2-- that contains an Fsa and also a _k2.RaggedShape that dictates how each of the paths relate to the original supervision segments. But I guess we could draft this pipeline without the data structure.

Supposing we have the Nbest with ragged numbers of paths, we can then add epsilon self-loops and intersect it with the lattices, after moving the 'tokens' to the 'labels' of the lattices; we'd then get the 1-best path and remove epsilons so that we get an Nbest that has just the best path's tokens and no epsilons. (We could define, in class Nbest, a form of intersect() that does the right thing when composing with an Fsa representing an FsaVec; we might also define wrappers for some Fsa operations so they work also on Nbest).

So at this point we have an Nbest with ragged numbers of paths up to 1000 (depending how many unique paths we got) and that is just a linear sequence of arcs, one per token; and it has costs defined per token. (It may also have other types of label and cost that were passively inherited). The way we allocate these costs, e.g. of epsilons and token-repeats, to each token will of course be a little arbitrary-- it's a function of how the epsilon removal algorithm works-- and we can try to figure out later on whether it needs to be changed somehow.

We get the total_scores of this Nbest object; they will be used in determining which ones to use in the first n-best list that we rescore. We can define its total_scores() function so that it returns it as a ragged array, which it logically is.
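As a rough picture of what such an object could look like, here is a plain-Python stand-in. The class name `NbestSketch`, its fields, and the use of a `row_splits` list in place of a `_k2.RaggedShape` are all assumptions for illustration; the real Nbest would hold an Fsa rather than Python lists.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class NbestSketch:
    """Hypothetical stand-in for the proposed Nbest class."""

    tokens: List[List[int]]    # one epsilon-free token sequence per path
    scores: List[List[float]]  # per-token scores, same layout as `tokens`
    row_splits: List[int]      # paths row_splits[i]:row_splits[i+1] belong
                               # to supervision i

    def total_scores(self) -> List[List[float]]:
        """Return the total score of each path as a ragged
        [supervision][path] list of lists."""
        totals = [sum(s) for s in self.scores]
        return [
            totals[begin:end]
            for begin, end in zip(self.row_splits[:-1], self.row_splits[1:])
        ]


# Two supervisions: the first has two paths, the second has one.
nbest = NbestSketch(
    tokens=[[5, 7], [5, 8], [9]],
    scores=[[-1.0, -2.0], [-1.5, -0.5], [-4.0]],
    row_splits=[0, 2, 3],
)
ragged = nbest.total_scores()
```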

danpovey commented 3 years ago

OK, the next step is to determine the subset of paths in the Nbest object to rescore. The input to this process is the ragged array of total_scores that we obtained as mentioned above from composing with the lattices, and the immediate output of this would be RaggedInt/Ragged containing the subset of idx01's into the Nbest object that we want to retain. [This will be regular, i.e. we keep the same number from each supervision, even if this means having to use repeats. We'll have to figure out later what to do in case no paths survived in one of the supervisions.] We can use the shape of this to create the new Nbest object, indexing the Fsa of the original Nbest object with the idx01's to get the correct subset.

For the very first iteration of our code we can just have this take the most likely n paths, although this is likely not optimal (might not have enough diversity). We can figure this out later. So at this point we still have an Nbest object, but it has a regular structure so will be easier to do rescoring with.

Note: it is important that we have the original acoustic and LM scores per token (as the 'scores' in the FSA), because we will later have a prediction scheme that makes use of these.
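The "most likely n paths, with repeats" selection could look roughly like this. It is a sketch on Python lists rather than ragged tensors; the function name `select_top_n` is hypothetical, scores are assumed to be log-domain (higher is better), and raising an error for an empty supervision is just one placeholder choice for the case that is left open above.

```python
from typing import List


def select_top_n(total_scores: List[List[float]], n: int) -> List[List[int]]:
    """Pick n paths per supervision, best total score first.

    `total_scores` is ragged: [supervision][path]. The result is regular,
    [supervision][n], and holds idx01 values, i.e. indexes into the
    flattened list of all paths. If a supervision has fewer than n paths,
    its paths are repeated cyclically.
    """
    selected = []
    offset = 0  # idx01 of the first path of the current supervision
    for scores in total_scores:
        if not scores:
            # Placeholder policy; what to do here is still undecided.
            raise ValueError("supervision with no surviving paths")
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        selected.append([offset + order[k % len(order)] for k in range(n)])
        offset += len(scores)
    return selected


idx01 = select_top_n([[-3.0, -2.0, -5.0], [-4.0]], n=2)
```

Here the second supervision has only one path, so that path is used twice to keep the structure regular.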

danpovey commented 3 years ago

Any rescoring processes we have (e.g. LM rescoring, transformer decoding) should produce an Nbest object with the exact same structure as the one produced in the comment above, i.e. with a regular number of paths per supervision, like 10.

We'll need this exact same structure to be preserved so that our process for finding the n-best paths to rescore will work. It is a selection process, where, from the paths that we have not selected in the 1st round of rescoring, we compute the expected total-score-after-rescoring of the path as a Gaussian distribution, and we rank them by the probability of being better than the best path from the 1st round. This probability requires the Gaussian integral, but we just need ranks so we can rank them by sigma value: i.e. the position, in standard deviations of the distribution, of the best score from the 1st round of rescoring.
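A small sketch of the ranking by sigma value follows. The function names are hypothetical, and it assumes we already have a predicted mean and a strictly positive standard deviation for the total score of each not-yet-rescored path. Since P(path beats best) = 1 - Φ(z) with z = (best - mean) / stddev is decreasing in z, sorting by ascending z gives the same order as sorting by descending probability, without evaluating the Gaussian integral.

```python
import math
from typing import List, Tuple


def sigma_value(best_score: float, mean: float, stddev: float) -> float:
    """Position of the 1st-round best score, in standard deviations of
    the predicted distribution of this path's score (stddev must be > 0)."""
    return (best_score - mean) / stddev


def prob_better(best_score: float, mean: float, stddev: float) -> float:
    """P(score > best_score) for score ~ N(mean, stddev^2); shown only to
    check that ranking by sigma value is equivalent."""
    z = sigma_value(best_score, mean, stddev)
    return 0.5 * math.erfc(z / math.sqrt(2.0))


def rank_by_sigma(best_score: float,
                  predictions: List[Tuple[float, float]]) -> List[int]:
    """Return path indexes, most promising first.

    `predictions[i]` is the (mean, stddev) predicted for path i.
    """
    return sorted(
        range(len(predictions)),
        key=lambda i: sigma_value(best_score, *predictions[i]),
    )


# Best total score after the 1st round of rescoring (log domain).
best = -10.0
preds = [(-12.0, 1.0), (-11.0, 4.0), (-10.5, 0.1)]
order = rank_by_sigma(best, preds)
```

Note that the path with the highest predicted mean (index 2) ranks last here, because its prediction is very confident; the uncertain path (index 1) is the most likely to beat the current best.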

This will require us to train a simple model to predict the total-score of a path. For each word-position in each of the remaining paths (i.e. those that were not selected in the 1st pass), we want to predict the score for that position after rescoring as a Gaussian. Let an "initial-score" be an element of the .scores of the n-best lists before neural rescoring, and a "final-score" be an element of the .scores of the n-best lists after neural rescoring. For each position that we want to predict, the inputs are:

The inputs to this model include the mean and variance of the best-matching positions, and an n-gram order. What I mean here is: for a particular position in a path, we find the longest-matching sequence (i.e. up to and including this word) in any of the n-best lists that we actually rescored; and if there are multiple with the same longest length, we treat them as a set (if there is just one, the variance would be 0). We can also provide this count to the model. The mean and variance here are the mean and variance of the scores at those longest-matching positions.

Now, it might look like this process of finding this set of longest-matching words, and computing the mean and variance of the scores, would be very time-consuming. Actually it can be done very efficiently (linear time in the total number of words we are processing, including words in paths that we selected in the 1st pass and those we did not, i.e. queries, and keys), although the algorithms will need to be done on CPU for now because they are too complex to implement on GPU in a short timeframe. I'll describe these algorithms in the next comment on this issue.

danpovey commented 3 years ago

Let me first describe an internal interface for the code that gets the (mean, variance, ngram_order) of the best-matching positions that were rescored in the 1st round. I'm choosing a level of interface that will let you know the basic picture, but there will be other interfaces above and below. Something like this, assuming it's in Python:

    from typing import Tuple

    import k2
    from torch import Tensor


    def get_best_matching_stats(tokens: k2.RaggedInt, scores: Tensor, counts: Tensor,
                                eos: int, min_token: int, max_token: int,
                                max_order: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        For "query" sentences, this function gets the mean and variance of scores from the best matching
        words-in-context in a set of provided "key" sentences.  This matching process matches the word and
        the words preceding it, looking for the highest-order match it can find.  It is an efficient implementation
        using suffix arrays (done on CPU for now, since the implementation is not very trivial).  The intended
        application is in estimating the scores of hypothesized transcripts, when we have actually computed
        the scores for only a subset of the hypotheses.

          tokens: a k2 ragged tensor with 3 axes, where the 1st axis is over separate utterances, the 2nd axis is
              over different elements of an n-best list, and the 3rd axis is over words.  Some sub-sub-lists represent
              queries, some represent keys to be queried; see scores and counts.  This would likely be something like
              the following:
                 [ [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ], [ [ hi, my, name, is, eos ], [ bye, my, name, is, eos ] ], ... ]
              where the words would actually be the corresponding integers, and they might not even correspond to
              words (might be BPE pieces).
          scores: a torch.Tensor with shape equal to (tokens.num_elements(),), with dtype==torch.float.
              These are the values that we want to get the mean and variance of (would likely be the
              scores of words or tokens after a first round of n-best list rescoring).  They only represent
              scores in "key positions" (meaning: positions where the corresponding counts value is nonzero); in
              "query positions", where the counts value is 0, the score is required to be zero.
          counts: a torch.Tensor with shape equal to (tokens.num_elements(),), with dtype==torch.int32,
              where the values should be 1 for positions where the 'scores' are nontrivially defined (representing
              keys) and 0 for positions where the 'scores' are zero (representing queries).  In our example,
              let's take counts to be: [ [ [ 1, 1, 1, 1 ], [ 0, 0, 0, 0 ] ], [ [ 1, 1, 1, 1, 1 ], [ 0, 0, 0, 0, 0 ] ], ... ]
          eos: the value of the eos (end of sentence) symbol; this is used as an extra padding value before
              the first path of each utterance, to ensure consistent behavior when matching past the
              beginning of the sentence.
          min_token: the lowest value of token that might be included in `tokens`, including BOS and EOS symbols;
              may be negative, like -1.
          max_token: might equal the vocabulary size, or simply the maximum token value included
              in this particular example.
          max_order: the maximum n-gram order to ever match; will be used as a limit on the
              `ngram_order` returned (but not on the actual length of match), and also will be
              used when a match extends all the way to the beginning of a sentence, including the
              implicit beginning-of-sentence symbol.
              (Note: if there is also an explicit bos symbol at the beginning of each sentence, it doesn't
              matter.)

        Returns a tuple (mean, var, count, ngram_order), where:
          mean is a torch.Tensor with shape equal to (tokens.num_elements(),), with dtype==torch.float,
              representing the mean of the scores over the set of longest-matching key positions;
              this is defined for both key positions and query positions, although the caller may not be
              interested in the value at key positions.
          var is a torch.Tensor with shape equal to (tokens.num_elements(),), with dtype==torch.float,
              representing the variance of the scores over the set of longest-matching key positions.
              This is expected to be zero at positions where count equals 1.
          count is a torch.Tensor with shape equal to (tokens.num_elements(),), with dtype==torch.int32,
              representing the number of longest-matching key positions.  This will be 1 if there was only
              a single position of the longest-matching order, and otherwise a larger number (note:
              if no words at all matched, ngram_order would be zero and the mean and variance would encompass
              all positions in all paths for the current utterance).
          ngram_order is a torch.Tensor with shape equal to (tokens.num_elements(),), with dtype==torch.int32,
              representing the ngram order of the best-matching word-in-context to the current word, up to
              max_order; or max_order in the case where we match up to the beginning of a sentence.  Example:
              in the case of 'name', in the 2nd sentence of the 2nd utterance, the ngram_order would be
              2, corresponding to the longest-matching sequence "my name".  In the case of 'fed', in the 2nd
              sentence of the 1st utterance, the ngram_order would be 0.  In the case of 'cat', in the 2nd
              sentence of the 1st utterance, the ngram_order would equal max_order because we match
              all the way to the beginning of the sentence.
        """
        pass
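To pin down the intended semantics, it may help to have a slow reference to compare against. The following is a brute-force, quadratic-time sketch for the paths of a single utterance, written on plain Python lists. It is my reading of the docstring above, not the planned suffix-array implementation. The function name, the list-based arguments, the use of population variance, and the choice to pool only key positions when nothing matches are all assumptions.

```python
from typing import List, Tuple

Stats = Tuple[float, float, int, int]  # (mean, var, count, ngram_order)


def best_matching_stats_bruteforce(
    paths: List[List[int]],
    scores: List[List[float]],
    is_key: List[bool],
    max_order: int,
) -> List[List[Stats]]:
    """Reference for ONE utterance; `is_key[p]` marks rescored paths."""
    # Every (path, position) pair that carries a real score.
    keys = [(p, j) for p, path in enumerate(paths) if is_key[p]
            for j in range(len(path))]
    assert keys, "at least one key path is required"
    result = []
    for q in paths:
        row = []
        for i in range(len(q)):
            # Default when no word matches: pool all key positions.
            best_len, best = 0, list(keys)
            for p, j in keys:
                k = paths[p]
                m = 0  # number of tokens matched, walking backwards
                while m <= i and m <= j and q[i - m] == k[j - m]:
                    m += 1
                # If both histories run out together, the match also covers
                # the implicit begin-of-sentence symbol, so count it.
                length = m + 1 if (m == i + 1 and m == j + 1) else m
                if length > best_len:
                    best_len, best = length, [(p, j)]
                elif length == best_len and length > 0:
                    best.append((p, j))
            vals = [scores[p][j] for p, j in best]
            mean = sum(vals) / len(vals)
            var = sum((v - mean) ** 2 for v in vals) / len(vals)
            reached_bos = best_len == i + 2
            order = max_order if reached_bos else min(best_len, max_order)
            row.append((mean, var, len(vals), order))
        result.append(row)
    return result


# 1st utterance of the docstring example: the=1, cat=2, said=3, fed=4, eos=0.
stats = best_matching_stats_bruteforce(
    paths=[[1, 2, 3, 0], [1, 2, 4, 0]],
    scores=[[-1.0, -2.0, -3.0, -0.5], [0.0, 0.0, 0.0, 0.0]],
    is_key=[True, False],
    max_order=3,
)
```

With this input, `stats[1][1]` ('cat' in the query) has ngram_order equal to max_order, and `stats[1][2]` ('fed') has ngram_order 0 with count 4, which agrees with the examples in the docstring.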

The implementation of this function will use suffix arrays. For now everything will be done on the CPU. The basic plan is as follows; let's say we do it separately for each utterance. We reverse the order of the words (and possibly utterances; utterance order doesn't matter though), then append an extra eos symbol, add 1 - min_token to everything to avoid zero and negative values, and append a zero, so that, for the 1st utterance above, we'd have: [ eos+n, fed+n, cat+n, the+n, eos+n, said+n, cat+n, the+n, eos+n, 0 ], where n equals 1 - min_token.

Next we compute the suffix array, which is an array of int32 of the same size as the list above, giving a lexicographical sorting of the suffixes of the sequence starting at each position. This can be done reasonably simply in O(n) time; see, for example, the C++ code in: https://algo2.iti.kit.edu/documents/jacm05-revised.pdf

Next we need to compute the LCP array from the suffix array (the array of lengths of the longest common prefixes between successive positions in the suffix array); see: https://www.geeksforgeeks.org/%C2%AD%C2%ADkasais-algorithm-for-construction-of-lcp-array-from-suffix-array/

The suffix array is an efficient data structure that can be used to simulate algorithms on "suffix tries"; a suffix trie is a compressed tree of suffixes of the string. Viewed as an algorithm on the suffix trie, we can compute the things we need as follows:
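As a rough illustration of the preprocessing only (building the reversed sequence, the suffix array and the LCP array), and not of the suffix-trie algorithm itself, here is a minimal Python sketch. The function names are hypothetical. The suffix array is built by plain sorting, which is much slower than the linear-time construction referenced above but easy to check; the LCP array uses Kasai's algorithm. The offset is taken to be 1 - min_token so that every shifted token is strictly positive and the final 0 is a unique smallest symbol.

```python
from typing import List


def build_sequence(paths: List[List[int]], eos: int, min_token: int) -> List[int]:
    """Reverse all words of one utterance, append an extra eos, shift the
    tokens so they are all >= 1, and terminate with 0."""
    offset = 1 - min_token
    flat = [tok for path in paths for tok in path]
    seq = [tok + offset for tok in reversed(flat)]
    seq.append(eos + offset)
    seq.append(0)
    return seq


def suffix_array_naive(seq: List[int]) -> List[int]:
    """Start positions of the suffixes of seq in lexicographic order.
    Naive sort, for illustration only."""
    return sorted(range(len(seq)), key=lambda i: seq[i:])


def lcp_kasai(seq: List[int], sa: List[int]) -> List[int]:
    """lcp[r] = length of the longest common prefix of the suffixes
    starting at sa[r - 1] and sa[r]; lcp[0] is 0 by convention."""
    n = len(seq)
    rank = [0] * n
    for r, start in enumerate(sa):
        rank[start] = r
    lcp = [0] * n
    h = 0
    for i in range(n):
        if rank[i] == 0:
            h = 0
            continue
        j = sa[rank[i] - 1]  # suffix just before suffix i in sorted order
        while i + h < n and j + h < n and seq[i + h] == seq[j + h]:
            h += 1
        lcp[rank[i]] = h
        if h > 0:
            h -= 1  # the next suffix shares at least h - 1 symbols
    return lcp


# 1st utterance of the example, with the=1, cat=2, said=3, fed=4, eos=0.
seq = build_sequence([[1, 2, 3, 0], [1, 2, 4, 0]], eos=0, min_token=0)
sa = suffix_array_naive(seq)
lcp = lcp_kasai(seq, sa)
```

Because the words are reversed, a common prefix of two suffixes here corresponds to a common word history in the original sentences, which is why the LCP values are what the matching needs.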

csukuangfj commented 3 years ago

I will first implement the ideas in the first comment, i.e., the Nbest class, and try it in https://github.com/k2-fsa/snowfall/pull/198 . I will move on to the next comments after that is done.

danpovey commented 3 years ago

Incidentally, regarding padding, speechbrain has something called undo_padding https://github.com/speechbrain/speechbrain/pull/751#issuecomment-879000314 which might possibly be useful. This is just something I noticed; if you disagree please ignore it.

pkufool commented 3 years ago

For the very first iteration of our code we can just have this take the most likely n paths, although this is likely not optimal (might not have enough diversity).

So these n paths (after rescoring) will be the keys used to calculate the mean and variance, and the other paths that were not selected will be the queries. Is that right?

danpovey commented 3 years ago

Yes.