kensho-technologies / pyctcdecode

A fast and lightweight python-based CTC beam search decoder for speech recognition.
Apache License 2.0

How are partial hypotheses managed? #48

Closed TParcollet closed 2 years ago

TParcollet commented 2 years ago

Hi there!

May I ask how partial hypotheses are handled in your n-gram rescoring implementation? For instance, what if the AM outputs BPE tokens while the n-gram LM is at the word level? How is rescoring performed to ensure that all hypotheses are checked and the rescoring isn't applied only once the first space token is encountered?

Thanks!

gkucsko commented 2 years ago

Hi Titouan, thanks for the question. The token merging happens at the decoder level, before anything gets sent off to the LM scoring part, so from an LM perspective there isn't much difference compared to character-level decoding. As for how it's scored: any finished word (with a trailing space decoded, or in BPE followed by a _ token) is included in the KenLM scoring, and any unfinished word is scored using a trie that checks whether the partial word is OOV (including a consideration for its length). Does that make sense? We also tried normalizing the partial scores to the unigram LM probabilities (i.e. summing up all potential completion probabilities), but we ended up with slightly worse results, so we kept this simpler approach. Best, Georg
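
For intuition, here is a minimal sketch of that split, not the library's actual code (which scores the full text in n-gram context): completed words go through the KenLM model, while a trailing partial word only gets the trie-based OOV check. `kenlm_word_score` and `partial_token_score` are hypothetical stand-ins for the corresponding pyctcdecode internals.

    from typing import Callable

    def score_beam_text(
        text: str,
        kenlm_word_score: Callable[[str], float],     # hypothetical: n-gram LM word score
        partial_token_score: Callable[[str], float],  # hypothetical: trie-based OOV check
    ) -> float:
        """Sketch: score a beam whose last word may still be incomplete."""
        # A trailing space means the last word is finished.
        ends_on_boundary = text.endswith(" ")
        words = text.split()
        if ends_on_boundary or not words:
            complete, partial = words, ""
        else:
            complete, partial = words[:-1], words[-1]
        # Finished words get full LM scoring (n-gram context ignored here for brevity).
        score = sum(kenlm_word_score(word) for word in complete)
        # The unfinished fragment only gets the cheap prefix/OOV penalty.
        if partial:
            score += partial_token_score(partial)
        return score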

TParcollet commented 2 years ago

Thanks for the answer! So partial words are rescored using a tree-like structure?

gkucsko commented 2 years ago

Yeah, basically it's a fast check on whether the prefix exists within the vocabulary. The additional parameters unk_score_offset and AVG_TOKEN_LEN allow tuning how it is weighted relative to the LM-scored fragments. The most important contribution here is to heavily penalize a prefix that is already out of vocabulary but not yet followed by a space, and therefore not yet included in the KenLM scoring part.

    def score_partial_token(self, partial_token: str) -> float:
        """Get partial token score."""
        if self._char_trie is None:
            # no vocabulary available, so treat every partial token as OOV
            is_oov = 1.0
        else:
            # has_node returns 0 if the partial token is not a prefix of any vocab word
            is_oov = int(self._char_trie.has_node(partial_token) == 0)
        unk_score = self.unk_score_offset * is_oov
        # if unk token length exceeds expected length then additionally decrease score
        if len(partial_token) > AVG_TOKEN_LEN:
            unk_score = unk_score * len(partial_token) / AVG_TOKEN_LEN
        return unk_score

(And within this function the notion of BPE is no longer present, since the merging already happened.)
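
As a quick illustration of the behavior, here is a hedged, standalone sketch of the function above, assuming the trie is a pygtrie.CharTrie built over the unigram vocabulary and using illustrative values for unk_score_offset and AVG_TOKEN_LEN (the actual defaults live in pyctcdecode's constants):

    import pygtrie

    UNK_SCORE_OFFSET = -10.0  # illustrative; pyctcdecode exposes this as unk_score_offset
    AVG_TOKEN_LEN = 6         # illustrative expected average token length

    def score_partial_token(partial_token: str, char_trie: pygtrie.CharTrie) -> float:
        """Standalone version of the method above, for demonstration only."""
        # has_node is nonzero when partial_token is a word or a prefix of one
        is_oov = int(char_trie.has_node(partial_token) == 0)
        unk_score = UNK_SCORE_OFFSET * is_oov
        # scale the penalty up for OOV fragments longer than the expected token length
        if len(partial_token) > AVG_TOKEN_LEN:
            unk_score = unk_score * len(partial_token) / AVG_TOKEN_LEN
        return unk_score

    trie = pygtrie.CharTrie()
    for word in ["hello", "help", "world"]:
        trie[word] = 0

    print(score_partial_token("hel", trie))        # 0.0   -> valid prefix, no penalty
    print(score_partial_token("xq", trie))         # -10.0 -> OOV prefix, flat penalty
    print(score_partial_token("xqzzbrfat", trie))  # -15.0 -> long OOV prefix, scaled penalty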

gkucsko commented 2 years ago

Closing, please feel free to reopen if you have more questions.