Hey @tanujjain,
We are very interested in adding beam search for Wav2Vec2 + LM support in general, but sadly don't find the time to do so at the moment. We would be really happy about a contribution if you want to give it a try.
As a start, we could add the logic to examples/research_projects/wav2vec2, and if it's clean, then move it upstream to src/transformers.
@patrickvonplaten Sure, I'll give it a go.
Hello @patrickvonplaten and @tanujjain,
I have already worked with prefix beam search decoding with language models for wav2vec2 and would like to implement it for huggingface, if you guys are okay with it.
PRs are very much welcome!
Any update on this? Specifically, is there any transformer-based LM that one can use with wav2vec 2.0?
As a quick solution, I used the code by the original author of the algorithm, which can be found here.
import numpy as np
import math
import collections

NEG_INF = -float("inf")

def make_new_beam():
    fn = lambda: (NEG_INF, NEG_INF)
    return collections.defaultdict(fn)

def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp

def decode(probs, beam_size=100, blank=0):
    """
    Performs inference for the given output probabilities.
    Arguments:
        probs: The output probabilities (e.g. post-softmax) for each
            time step. Should be an array of shape (time x output dim).
        beam_size (int): Size of the beam to use during inference.
        blank (int): Index of the CTC blank label.
    Returns the output label sequence and the corresponding negative
    log-likelihood estimated by the decoder.
    """
    T, S = probs.shape
    probs = np.log(probs)

    # Elements in the beam are (prefix, (p_blank, p_no_blank)).
    # Initialize the beam with the empty sequence, a probability of
    # 1 for ending in blank and zero for ending in non-blank
    # (in log space).
    beam = [(tuple(), (0.0, NEG_INF))]

    for t in range(T):  # Loop over time
        next_beam = make_new_beam()  # A default dictionary to store the next step candidates.
        for s in range(S):  # Loop over vocab
            p = probs[t, s]
            # The variables p_b and p_nb are respectively the
            # probabilities for the prefix given that it ends in a
            # blank and does not end in a blank at this time step.
            for prefix, (p_b, p_nb) in beam:  # Loop over beam
                # If we propose a blank the prefix doesn't change.
                # Only the probability of ending in blank gets updated.
                if s == blank:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)
                    continue
                # Extend the prefix by the new character s and add it to
                # the beam. Only the probability of not ending in blank
                # gets updated.
                end_t = prefix[-1] if prefix else None
                n_prefix = prefix + (s,)
                n_p_b, n_p_nb = next_beam[n_prefix]
                if s != end_t:
                    n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
                else:
                    # We don't include the previous probability of not ending
                    # in blank (p_nb) if s is repeated at the end. The CTC
                    # algorithm merges characters not separated by a blank.
                    n_p_nb = logsumexp(n_p_nb, p_b + p)
                # *NB* this would be a good place to include an LM score.
                next_beam[n_prefix] = (n_p_b, n_p_nb)  ## add lm here
                # If s is repeated at the end we also update the unchanged
                # prefix. This is the merging case.
                if s == end_t:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_nb = logsumexp(n_p_nb, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)
        # Sort and trim the beam before moving on to the
        # next time-step.
        beam = sorted(next_beam.items(),
                      key=lambda x: logsumexp(*x[1]),
                      reverse=True)
        beam = beam[:beam_size]

    best = beam[0]
    return best[0], -logsumexp(*best[1])

# Try the algorithm on a random example.
time = 50
output_dim = 20
batch_size = 16
batch_probs = np.random.rand(batch_size, time, output_dim)
decoded_batch = []
for b in batch_probs:
    norm_b = b / np.sum(b, axis=1, keepdims=True)  # normalize each time step to a distribution
    decoded_batch.append(decode(norm_b, beam_size=3)[0])
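To make the "## add lm here" note concrete, here is a minimal sketch of shallow fusion at that point. The scorer lm_log_prob and the weights ALPHA/BETA are placeholders of mine, not part of the original code; this variant scores every prefix extension, whereas a word-level LM would fire only when a space character is appended:

ALPHA, BETA = 0.5, 1.5  # assumed LM weight and insertion bonus

def lm_log_prob(prefix, s):
    """Placeholder scorer: log P(s | prefix). Swap in a real LM."""
    return 0.0  # a uniform LM leaves the search unchanged

def extension_update_with_lm(n_p_nb, p_b, p_nb, p, prefix, s, end_t):
    """The prefix-extension update from decode() above, with the LM fused in."""
    lm = ALPHA * lm_log_prob(prefix, s) + BETA
    if s != end_t:
        return logsumexp(n_p_nb, p_b + p + lm, p_nb + p + lm)
    # Repeated label: only the path ending in blank can extend the prefix.
    return logsumexp(n_p_nb, p_b + p + lm)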
Trying to add a language model (for German) like so:
from transformers import AutoTokenizer, AutoModelWithLMHead

tokenizer_de = AutoTokenizer.from_pretrained("dbmdz/german-gpt2")
model_de = AutoModelWithLMHead.from_pretrained("dbmdz/german-gpt2", return_dict_in_generate=True)

def lm_prob(sentence):
    # NB: the last word may split into several subword ids, so this indexes
    # the single next-token distribution with a list of token ids.
    last_word_token = tokenizer_de.encode(sentence.split(' ')[-1])
    earlier_sentence = ' '.join(sentence.split(' ')[:-1])
    input_ids_earlier_sent = tokenizer_de.encode(earlier_sentence, return_tensors="pt")  # tokenize the rest of the sentence
    generated_outputs_lm = model_de.generate(input_ids_earlier_sent,
                                             max_length=len(input_ids_earlier_sent[0]) + 1,
                                             do_sample=True,
                                             num_return_sequences=1,
                                             output_scores=True)
    sftmax_prob_lm = generated_outputs_lm.scores[0].softmax(-1)  # distribution over the next token
    prob = sftmax_prob_lm[0, last_word_token]
    return prob
The LM snippet should give the probability of the last word in a beam given all the preceding context, but the probabilities for the words I expect are almost always close to zero, so I'm still figuring out how to use the LM properly. Hence, I haven't integrated the LM with the snippet above.
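For what it's worth, one likely culprit: the last word usually splits into several subword tokens, so indexing a single next-token distribution with all of those ids compares probabilities at the wrong positions. A sketch of an alternative I would try (untested; the function name is mine) that scores the word as the sum of its subword log-probabilities from one forward pass:

import torch

def lm_word_logprob(sentence):
    """Log P(last word | preceding words), summed over subword tokens."""
    words = sentence.split(' ')
    context = ' '.join(words[:-1])
    full_ids = tokenizer_de.encode(sentence, return_tensors="pt")
    n_ctx = len(tokenizer_de.encode(context))  # caveat: BPE may split the context differently inside the full sentence
    with torch.no_grad():
        log_probs = model_de(full_ids).logits.log_softmax(-1)  # (1, seq_len, vocab)
    total = 0.0
    for i in range(n_ctx, full_ids.shape[1]):
        # The token at position i is predicted from position i - 1.
        total += log_probs[0, i - 1, full_ids[0, i]].item()
    return total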
As for a decent implementation of beam search for CTC, I'm thinking along the lines of running the above algorithm (not the same code, obviously) with each sequence in the batch running an independent beam search on a different thread/process.
Anyone with less complex implementation ideas?
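A rough sketch of that fan-out using only the standard library (assuming the decode function above lives at module level, which process pickling requires):

from concurrent.futures import ProcessPoolExecutor

def decode_batch(batch_probs, beam_size=100, blank=0, workers=4):
    """Run one independent prefix beam search per sequence in the batch."""
    with ProcessPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(decode, seq, beam_size, blank)
                   for seq in batch_probs]
        return [f.result()[0] for f in futures]  # keep only the label sequences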
Found another implementation here (without consideration for batch inference).
> As for a decent implementation of beam search for CTC, I'm thinking along the lines of running the above algorithm (not the same code, obviously) with each sequence in the batch running an independent beam search on a different thread/process.
There you go: https://github.com/mozilla/DeepSpeech/blob/master/native_client/ctcdecode/ctc_beam_search_decoder.cpp#L287
I'd highly encourage also considering returning the frames where the probability of each token spikes, as they can be used for alignment. Mozilla did this in their implementation and it works quite nicely.
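As a rough illustration of that idea (mine, not Mozilla's code), the greedy variant is the easiest to show: record, for each emitted token, the frame where its probability peaks. A beam search version would carry the same per-token frame list inside each beam entry:

import numpy as np

def greedy_decode_with_frames(probs, blank=0):
    """Greedy CTC decode that also returns, per emitted token, the frame
    where its probability spikes (usable as a rough alignment)."""
    best = probs.argmax(axis=1)  # best label per frame, shape (T,)
    tokens, frames = [], []
    prev = blank
    for t, s in enumerate(best):
        if s != blank and s != prev:
            tokens.append(int(s))
            frames.append(t)  # first frame of the new token
        elif s != blank and s == prev:
            # Same token continuing: keep the frame where it peaks.
            if probs[t, s] > probs[frames[-1], s]:
                frames[-1] = t
        prev = s
    return tokens, frames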
Is there any restriction on the programming language? The computational complexity of the algorithm is quite high, and CTC beam search decoding is often the bottleneck.
I think we can try to add a dependency on wav2letter: https://github.com/flashlight/wav2letter and add LM decoding as explained for fairseq here: https://github.com/pytorch/fairseq/blob/master/examples/wav2vec/README.md#evaluating-a-ctc-model . It would be awesome if we manage to create a nice run_wav2vec2_eval_with_lm.py script that people can use out of the box with every wav2vec2 model. We could also make a nice blog post out of this and publish it on our blog :-)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
ping
For future developers, you may find this implementation useful; I tried to keep the code as simple as possible: https://github.com/farisalasmary/wav2vec2-kenlm
I'm now working on this topic full time.
We will most likely foster a closer collaboration between pyctcdecode and Transformers. Here is a GitHub repo that shows how to use pyctcdecode with Wav2Vec2 for LM-supported decoding. It works quite well with KenLM.
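For anyone landing here later, the pyctcdecode flow is roughly the following (the model name and the KenLM path are placeholders; tuning of the LM weights is left out):

from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# pyctcdecode expects the vocabulary as a list of labels in id order.
vocab = [tok for tok, _ in sorted(processor.tokenizer.get_vocab().items(),
                                  key=lambda kv: kv[1])]
decoder = build_ctcdecoder(vocab, kenlm_model_path="lm.arpa")  # placeholder path

# For a single utterance: logits of shape (time, vocab) as a numpy array.
# logits = model(input_values).logits[0].detach().numpy()
# transcription = decoder.decode(logits)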
The Wav2Vec2CTCTokenizer.decode method only provides greedy decoding. Is there a beam search implementation for CTC available yet? I'm also aware of efforts to integrate a language model in #10794 and have had a look at the notebook here. Although it is a nice, simple way to integrate an LM, it is suboptimal when considering CTC semantics. A more appropriate approach would be the one described in this paper and explained in this distill.pub blog post. It would be great to have these features added (if they aren't already there and I somehow missed them).