huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Beam search decoding and language model integration for Wav2Vec2ForCTC models #11283

Status: Closed (closed by tanujjain 3 years ago)

tanujjain commented 3 years ago
  1. AFAIK, the Wav2Vec2CTCTokenizer.decode method only provides greedy decoding. Is there a beam search implementation for CTC available yet?
  2. Also, as is the norm in ASR modelling, a language model is generally added on top of the acoustic model. It would be nice to be able to attach a pretrained language model that is taken into account at beam search decoding time. Is there an out-of-the-box solution for that yet?
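For reference, greedy (best-path) CTC decoding takes the argmax token at each frame, merges consecutive repeats and then drops the blank. Decoding argmax ids with the tokenizer roughly amounts to this, with the tokenizer additionally mapping ids to characters and treating its pad token as the blank. A minimal sketch on a `(time x vocab)` array; `greedy_ctc_decode` is a hypothetical name, not a Transformers function:

```python
import numpy as np


def greedy_ctc_decode(logits, blank=0):
    """Best-path CTC decoding: argmax per frame, merge repeats, drop blanks."""
    best_path = np.asarray(logits).argmax(axis=-1)
    decoded = []
    prev = None
    for s in best_path:
        # A label is emitted only when it differs from the previous frame's
        # label and is not the blank; a blank between two equal labels
        # therefore lets the same label be emitted twice.
        if s != prev and s != blank:
            decoded.append(int(s))
        prev = s
    return decoded
```

Beam search differs in that it keeps several candidate prefixes and sums over all alignments that collapse to the same prefix, instead of committing to the single most likely frame-by-frame path.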

I'm also aware of the efforts to integrate a language model in #10794 and have had a look at the notebook here. Although it is a nice, simple way to integrate an LM, it is suboptimal with respect to CTC semantics. A more appropriate approach is the one described in this paper and explained in this Distill blog post. It would be great to have these features added (if they are not already there and I somehow missed them).
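To illustrate what integrating the LM inside the CTC beam search (rather than rescoring afterwards) can look like, here is a minimal sketch of prefix beam search with word-level shallow fusion. Each prefix is ranked by log P_ctc + alpha * (sum of LM log-probs of its completed words) + beta * (number of completed words), a common log-space form of this objective; whether it matches the exact formulation in the paper and blog post mentioned above is an assumption. `decode_with_lm`, `lm_log_prob`, `alpha`, `beta` and a character `alphabet` that uses a space as word delimiter are hypothetical names chosen here, not an existing Transformers API.

```python
import collections
import functools
import math

import numpy as np

NEG_INF = -float("inf")


def logsumexp(*args):
    """Numerically stable log(sum(exp(args)))."""
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))


def decode_with_lm(probs, alphabet, lm_log_prob, alpha=0.5, beta=1.0,
                   beam_size=16, blank=0, space=" "):
    """CTC prefix beam search with word-level LM shallow fusion (sketch).

    probs: (T, S) array of per-frame output probabilities (e.g. post-softmax).
    alphabet: alphabet[i] is the character for output index i.
    lm_log_prob: hypothetical callable mapping a tuple of words to
        log P(words[-1] | words[:-1]).
    """
    log_probs = np.log(probs)
    T, S = log_probs.shape

    @functools.lru_cache(maxsize=None)
    def lm_bonus(prefix):
        # Only words followed by a space count as complete; the trailing
        # partial word is not scored. Empty strings from repeated spaces are dropped.
        text = "".join(alphabet[i] for i in prefix)
        words = tuple(w for w in text.split(space)[:-1] if w)
        return sum(alpha * lm_log_prob(words[:k + 1]) + beta
                   for k in range(len(words)))

    def rank(item):
        prefix, (p_b, p_nb) = item
        return logsumexp(p_b, p_nb) + lm_bonus(prefix)

    # prefix -> (log P(prefix, ends in blank), log P(prefix, ends in non-blank))
    beam = [((), (0.0, NEG_INF))]
    for t in range(T):
        next_beam = collections.defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (p_b, p_nb) in beam:
            end_t = prefix[-1] if prefix else None
            for s in range(S):
                p = log_probs[t, s]
                if s == blank:
                    # A blank leaves the prefix unchanged.
                    n_p_b, n_p_nb = next_beam[prefix]
                    next_beam[prefix] = (logsumexp(n_p_b, p_b + p, p_nb + p), n_p_nb)
                    continue
                n_prefix = prefix + (s,)
                n_p_b, n_p_nb = next_beam[n_prefix]
                if s != end_t:
                    n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
                else:
                    # A repeated label only starts a new symbol after a blank.
                    n_p_nb = logsumexp(n_p_nb, p_b + p)
                next_beam[n_prefix] = (n_p_b, n_p_nb)
                if s == end_t:
                    # A repeat without a blank collapses into the unchanged prefix.
                    n_p_b, n_p_nb = next_beam[prefix]
                    next_beam[prefix] = (n_p_b, logsumexp(n_p_nb, p_nb + p))
        # The CTC recursion is untouched; the LM only changes ranking and pruning.
        beam = sorted(next_beam.items(), key=rank, reverse=True)[:beam_size]
    return "".join(alphabet[i] for i in beam[0][0])
```

Because the LM bonus depends only on the prefix, adding it at ranking time should order the beam the same way as multiplying the LM factor into p_b and p_nb whenever a space is emitted. With `alpha=0` and `beta=0` the sketch reduces to plain prefix beam search.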

patrickvonplaten commented 3 years ago

Hey @tanujjain,

We are very interested in adding beam search for Wav2Vec2 + LM support in general, but sadly don't have the time to do so at the moment. We would be really happy about a contribution if you want to give it a try.

As a start, we could add the logic to examples/research_projects/wav2vec2 and, if it's clean, move it upstream to src/transformers.

tanujjain commented 3 years ago

@patrickvonplaten Sure, I'll give it a go.

deepang17 commented 3 years ago

Hello @patrickvonplaten and @tanujjain,

I have already worked with prefix beam search decoding with language models for wav2vec2 and would like to implement it for huggingface, if you guys are okay with it.

patrickvonplaten commented 3 years ago

PRs are very much welcome!

sarim-zafar commented 3 years ago

Any update on this? Specifically, is there any Transformer-based LM that one can use with wav2vec 2.0?

tanujjain commented 3 years ago

As a quick solution, I used the code by the original author of the algorithm, which can be found here.

```python
import numpy as np
import math
import collections

NEG_INF = -float("inf")

def make_new_beam():
    fn = lambda : (NEG_INF, NEG_INF)
    return collections.defaultdict(fn)

def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp

def decode(probs, beam_size=100, blank=0):
    """
    Performs inference for the given output probabilities.
    Arguments:
      probs: The output probabilities (e.g. post-softmax) for each
        time step. Should be an array of shape (time x output dim).
      beam_size (int): Size of the beam to use during inference.
      blank (int): Index of the CTC blank label.
    Returns the output label sequence and the corresponding negative
    log-likelihood estimated by the decoder.
    """
    T, S = probs.shape
    probs = np.log(probs)

    # Elements in the beam are (prefix, (p_blank, p_no_blank))
    # Initialize the beam with the empty sequence, a probability of
    # 1 for ending in blank and zero for ending in non-blank
    # (in log space).
    beam = [(tuple(), (0.0, NEG_INF))]

    for t in range(T): # Loop over time
        next_beam = make_new_beam() # A default dictionary to store the next step candidates.
        for s in range(S): # Loop over vocab
            p = probs[t, s]
            # The variables p_b and p_nb are respectively the
            # probabilities for the prefix given that it ends in a
            # blank and does not end in a blank at this time step.
            for prefix, (p_b, p_nb) in beam: # Loop over beam
                # If we propose a blank the prefix doesn't change.
                # Only the probability of ending in blank gets updated
                if s == blank:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)
                    continue
                # Extend the prefix by the new character s and add it to
                # the beam. Only the probability of not ending in blank
                # gets updated.
                end_t = prefix[-1] if prefix else None
                n_prefix = prefix + (s,)
                n_p_b, n_p_nb = next_beam[n_prefix]
                if s != end_t:
                    n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
                else:
                    # We don't include the previous probability of not ending
                    # in blank (p_nb) if s is repeated at the end. The CTC
                    # algorithm merges characters not separated by a blank.
                    n_p_nb = logsumexp(n_p_nb, p_b + p)

                # *NB* this would be a good place to include an LM score.
                next_beam[n_prefix] = (n_p_b, n_p_nb) ## add lm here
                # If s is repeated at the end we also update the unchanged
                # prefix. This is the merging case.
                if s == end_t:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_nb = logsumexp(n_p_nb, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)
        # Sort and trim the beam before moving on to the
        # next time-step.
        beam = sorted(next_beam.items(),
            key=lambda x : logsumexp(*x[1]),
            reverse=True)
        beam = beam[:beam_size]
    best = beam[0]
    return best[0], -logsumexp(*best[1])

# Try the algo on an example
time = 50
output_dim = 20
batch_size = 16

batch_probs = np.random.rand(batch_size, time, output_dim)
decoded_batch = []
for b in batch_probs:
    norm_b = b/np.sum(b, axis=1, keepdims=True)
    decoded_batch.append(decode(norm_b, beam_size=3)[0])
```

I tried adding a language model (for German) like so:


```python
from transformers import AutoTokenizer, AutoModelWithLMHead
tokenizer_de = AutoTokenizer.from_pretrained("dbmdz/german-gpt2")
model_de = AutoModelWithLMHead.from_pretrained("dbmdz/german-gpt2", return_dict_in_generate=True)

def lm_prob(sentence):
    last_word_token = tokenizer_de.encode(sentence.split(' ')[-1])
    earlier_sentence = ' '.join(sentence.split(' ')[:-1])
    input_ids_earlier_sent = tokenizer_de.encode(earlier_sentence, return_tensors="pt")  # tokenize rest of the sentence
    generated_outputs_lm = model_de.generate(input_ids_earlier_sent,
                                   max_length=len(input_ids_earlier_sent[0]) + 1,
                                   do_sample=True, 
                                   num_return_sequences=1,
                                   output_scores=True)
    sftmax_prob_lm = generated_outputs_lm.scores[0].softmax(-1)
    prob = sftmax_prob_lm[0, last_word_token]
    return prob
```

The LM snippet should give the probability of the last word in a beam given all the preceding words, but the probabilities for the words I expect are almost always close to zero, so I'm still working out how best to use the LM. Hence, I haven't integrated the LM into the beam search snippet above yet.
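Two things in the snippet above likely push the probabilities towards zero. First, with `do_sample=True`, the scores that `generate(..., output_scores=True)` returns are taken after the sampling warpers (top-k with k=50 by default) have been applied, so every token outside the top 50 gets a score of -inf and a probability of exactly zero. Second, `tokenizer_de.encode(last_word)` can split the word into several sub-word tokens, and, without a leading space, may not match how a GPT-2-style byte-level BPE encodes a word in the middle of a sentence; indexing the next-token distribution with all of them reads each sub-token as if it were the next token instead of chaining them. A minimal sketch of the chain-rule computation from a single forward pass, written against a plain logits array so it does not depend on a particular model (`word_log_prob` is a hypothetical helper name):

```python
import numpy as np


def log_softmax(logits):
    """Row-wise log-softmax of a (seq_len, vocab) array, numerically stable."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def word_log_prob(logits, context_len, word_ids):
    """log P(word | context) for a word made of one or more sub-word tokens.

    logits: (seq_len, vocab) causal-LM outputs for the input context_ids + word_ids;
        row j scores the token at position j + 1.
    context_len: number of context tokens (at least 1, e.g. a BOS token).
    word_ids: the word's sub-word token ids, in order.
    """
    if context_len < 1:
        raise ValueError("need at least one context token to condition on")
    log_probs = log_softmax(np.asarray(logits, dtype=np.float64))
    # Chain rule: log P(w) = sum_i log P(w_i | context, w_1 .. w_{i-1})
    return float(sum(log_probs[context_len - 1 + i, tok]
                     for i, tok in enumerate(word_ids)))
```

With Transformers, the logits could come from a plain forward pass such as `model_de(input_ids).logits[0]` (detached and converted to NumPy), where `input_ids` holds the context ids followed by `word_ids = tokenizer_de.encode(" " + last_word)`; this wiring is an untested assumption. Working in log space also lets the result be added directly to the beam scores, which are log-probabilities too.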

As for a decent implementation of beam search for CTC, I'm thinking along the lines of running the above algorithm (not the same code, obviously), with each sequence in the batch running an independent beam search in a separate thread/process.
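The per-sequence parallelism described above can be sketched with the standard library's `concurrent.futures`. `decode_batch` is a hypothetical helper; it assumes each sequence is decoded completely independently, and that padded frames have already been cut off so every item has its own length.

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from functools import partial


def decode_batch(batch_probs, decode_fn, max_workers=None, use_processes=True,
                 **decode_kwargs):
    """Decode every sequence of a batch independently and in parallel.

    batch_probs: iterable of (T_i, S) arrays; each may have its own length T_i.
    decode_fn: per-sequence decoder, e.g. a beam search like decode() above.
        With use_processes=True it must be picklable (a module-level function).
    """
    # Processes sidestep the GIL for pure-Python decoders; threads only help
    # if decode_fn spends its time in native code that releases the GIL.
    executor_cls = ProcessPoolExecutor if use_processes else ThreadPoolExecutor
    fn = partial(decode_fn, **decode_kwargs)
    with executor_cls(max_workers=max_workers) as pool:
        # map() yields results in input order, so they line up with the batch.
        return list(pool.map(fn, batch_probs))
```

With the `decode` function from the snippet above, a call like `decode_batch(list_of_normalized_probs, decode, beam_size=3)` should return one `(prefix, nll)` pair per sequence; on platforms that start worker processes with spawn (Windows, macOS), the call has to sit under an `if __name__ == "__main__":` guard.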

Does anyone have less complex implementation ideas?

Found another implementation here (without consideration for batch inference).

phtephanx commented 3 years ago

> As for a decent implementation of beam search for CTC, I'm thinking along the lines of running the above algorithm (not the same code, obviously), with each sequence in the batch running an independent beam search in a separate thread/process.

There you go: https://github.com/mozilla/DeepSpeech/blob/master/native_client/ctcdecode/ctc_beam_search_decoder.cpp#L287

I'd highly encourage you to also consider returning the frames where the probability of each token spikes, as they can be used for alignment. Mozilla did this in their implementation and it works quite nicely.
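A minimal sketch of the idea, shown on the greedy path for simplicity (Mozilla's decoder records timesteps inside its beam search instead): for each emitted token it reports the frame, within that token's run of consecutive frames, at which the token's probability peaks. `greedy_decode_with_frames` is a hypothetical helper. To convert frames to seconds, multiply by the model's frame stride; for the standard wav2vec 2.0 feature encoder that is 320 samples, i.e. 20 ms at 16 kHz, but it is worth checking against the model's config.

```python
import numpy as np


def greedy_decode_with_frames(probs, blank=0):
    """Greedy CTC decoding that also returns the peak frame of each emitted token."""
    probs = np.asarray(probs)
    best = probs.argmax(axis=1)
    T = len(best)
    tokens, frames = [], []
    t = 0
    while t < T:
        s = best[t]
        # Find the end of the run of frames that share this argmax label.
        run_end = t
        while run_end + 1 < T and best[run_end + 1] == s:
            run_end += 1
        if s != blank:
            # Frame within the run where this token's probability is highest.
            peak = t + int(probs[t:run_end + 1, s].argmax())
            tokens.append(int(s))
            frames.append(peak)
        t = run_end + 1
    return tokens, frames
```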

Is there any restriction on the programming language? The computational complexity of the algorithm is quite high, and CTC beam search decoding is often the bottleneck.
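Independent of the language, one common way to reduce that cost is to expand only a few candidate tokens per frame instead of the whole vocabulary. The sketch below keeps the `cutoff_top_n` most likely tokens and stops early once their cumulative probability reaches `cutoff_prob`; the parameter names mirror options found in some CTC decoders such as Mozilla's, but `candidate_tokens` itself is a hypothetical illustration.

```python
import numpy as np


def candidate_tokens(frame_probs, cutoff_top_n=40, cutoff_prob=1.0):
    """Token indices worth expanding at one frame, most likely first.

    Keeps at most cutoff_top_n tokens and, if cutoff_prob < 1.0, stops at the
    first token whose cumulative probability reaches cutoff_prob.
    """
    frame_probs = np.asarray(frame_probs)
    order = np.argsort(frame_probs)[::-1][:cutoff_top_n]
    if cutoff_prob < 1.0:
        cumulative = np.cumsum(frame_probs[order])
        # Index of the first token at which the cumulative mass reaches cutoff_prob.
        keep = int(np.searchsorted(cumulative, cutoff_prob)) + 1
        order = order[:keep]
    return order.tolist()
```

Inside `decode` above, the loop `for s in range(S)` would become `for s in candidate_tokens(np.exp(probs[t]))`, since `decode` has already converted `probs` to log space; the work per frame then scales with the number of candidates rather than with the full vocabulary.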

patrickvonplaten commented 3 years ago

I think we can try to add a dependency on wav2letter: https://github.com/flashlight/wav2letter and add LM decoding as explained here in fairseq: https://github.com/pytorch/fairseq/blob/master/examples/wav2vec/README.md#evaluating-a-ctc-model . It would be awesome if we managed to create a nice run_wav2vec2_eval_with_lm.py script that people can use out of the box with every wav2vec2 model. We can also make a nice blog post out of this and publish it on our blog :-)

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

patrickvonplaten commented 3 years ago

ping

farisalasmary commented 3 years ago

For future developers, you may find this implementation useful. I used the simplest code possible to develop it https://github.com/farisalasmary/wav2vec2-kenlm

patrickvonplaten commented 3 years ago

I'm now working on this topic full time.

We will most likely foster a closer collaboration between pyctcdecode and Transformers. Here is a GitHub repo that shows how to use pyctcdecode with Wav2Vec2 for LM-supported decoding. It works quite well with KenLM.