kensho-technologies / pyctcdecode

A fast and lightweight python-based CTC beam search decoder for speech recognition.
Apache License 2.0

Getting key error from the pyctcdecode package, any idea? #41

cleancoder7 closed 2 years ago

cleancoder7 commented 2 years ago
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 48, in mapstar
    return list(map(*args))
  File "/usr/local/lib/python3.8/dist-packages/pyctcdecode/decoder.py", line 547, in _decode_beams_mp_safe
    decoded_beams = self.decode_beams(
  File "/usr/local/lib/python3.8/dist-packages/pyctcdecode/decoder.py", line 525, in decode_beams
    decoded_beams = self._decode_logits(
  File "/usr/local/lib/python3.8/dist-packages/pyctcdecode/decoder.py", line 329, in _decode_logits
    language_model = BeamSearchDecoderCTC.model_container[self._model_key]
KeyError: b'\xf0\xaaD\x92+\x90\x16\xc9 \xf5,\xb4\x10\xb1y\x8e'
gkucsko commented 2 years ago

Hi, looks like an issue with model cleanup. We need this container trick to allow multiprocessing in Python without creating copies of the model. Can you post a minimal example that reproduces the error?
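
For readers unfamiliar with the container trick, here is a minimal sketch of how such a pattern could produce the KeyError in the traceback above. It is not the pyctcdecode implementation: `TinyDecoder` and its `decode` method are hypothetical stand-ins, and only the `model_container` / `_model_key` names are taken from the traceback. The idea is that the instance stores a random key while the model itself lives in a class-level dict, so pickling the instance never copies the model. A worker whose dict was never populated then fails on lookup.

```python
import os
import pickle


class TinyDecoder:
    """Hypothetical stand-in: keeps its model in a class-level dict."""

    model_container = {}  # shared by all instances within one process

    def __init__(self, model):
        # The instance only remembers a random key; the model lives in the dict.
        self._model_key = os.urandom(16)
        TinyDecoder.model_container[self._model_key] = model

    def decode(self, text):
        model = TinyDecoder.model_container[self._model_key]
        return model(text)


decoder = TinyDecoder(str.upper)
payload = pickle.dumps(decoder)  # the pickled instance holds the key, not the model

# Same process, container still populated: the key resolves.
same_process_result = pickle.loads(payload).decode("ok")

# Simulate a freshly started worker whose container was never populated.
TinyDecoder.model_container.clear()
try:
    pickle.loads(payload).decode("ok")
    lookup_failed = False
except KeyError:
    lookup_failed = True

print(same_process_result, lookup_failed)
```

Clearing the dict here only imitates a new interpreter. In a real run the empty container would come from the worker process starting without the parent's state.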

gkucsko commented 2 years ago

closing for inactivity, feel free to re-open if error persists

gfigueroa commented 2 years ago

@gkucsko I'm getting the same error when using the Wav2Vec2ProcessorWithLM transformer as exemplified in @patrickvonplaten's latest blog (https://huggingface.co/blog/wav2vec2-with-ngram).

Here is a minimum viable example that throws the error (last line):

import torch
from datasets import load_dataset
from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
audio_sample = dataset[2]

processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")

inputs = processor(audio_sample["audio"]["array"], sampling_rate=audio_sample["audio"]["sampling_rate"], return_tensors="pt")

with torch.no_grad():
  logits = model(**inputs).logits

transcription = processor.batch_decode(logits.numpy()).text
gkucsko commented 2 years ago

Great, thanks, I can reproduce. Reopening... For now, you can probably just use .decode. I will look into what's happening with multicore .batch_decode. Things are a bit sneaky here because of the GIL.

gfigueroa commented 2 years ago

Thanks for the quick reply @gkucsko. I'm getting a similar error with transcription = processor.decode(logits.numpy()).text:

KeyError                                  Traceback (most recent call last)
<ipython-input-7-d0acd36503cc> in <module>
----> 1 transcription = processor.decode(logits.numpy()).text
      2 transcription[0].lower()

~/.pyenv/versions/3.8.2/lib/python3.8/site-packages/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py in decode(self, logits, beam_width, beam_prune_logp, token_min_logp, hotwords, hotword_weight)
    358 
    359         # pyctcdecode
--> 360         decoded_beams = self.decoder.decode_beams(
    361             logits,
    362             beam_width=beam_width,

~/.pyenv/versions/3.8.2/lib/python3.8/site-packages/pyctcdecode/decoder.py in decode_beams(self, logits, beam_width, beam_prune_logp, token_min_logp, prune_history, hotwords, hotword_weight, lm_start_state)
    523             # convert logits into log probs
    524             logits = np.clip(_log_softmax(logits, axis=1), np.log(MIN_TOKEN_CLIP_P), 0)
--> 525         decoded_beams = self._decode_logits(
    526             logits,
    527             beam_width=beam_width,

~/.pyenv/versions/3.8.2/lib/python3.8/site-packages/pyctcdecode/decoder.py in _decode_logits(self, logits, beam_width, beam_prune_logp, token_min_logp, prune_history, hotword_scorer, lm_start_state)
    345             for idx_char in idx_list:
    346                 p_char = logit_col[idx_char]
--> 347                 char = self._idx2vocab[idx_char]
    348                 for (
    349                     text,

KeyError: 512
gkucsko commented 2 years ago

That looks like an unrelated error caused by a vocabulary mismatch. You probably need to squeeze out the first (batch) dimension before calling .decode. In the example you provided, .decode worked for me locally.
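
A minimal sketch of dropping the batch dimension, assuming `.decode` expects a single utterance as a 2-D (time, vocab) array while the model returns (batch, time, vocab). The array shape and the helper name `frames_for_decode` are made up for illustration. With a 3-D input, an index along the wrong axis is presumably looked up as a vocabulary index, which would explain a key such as 512 falling outside a small vocabulary.

```python
import numpy as np

# Hypothetical stand-in for model(**inputs).logits.numpy():
# 1 utterance, 600 frames, 32 vocabulary entries.
batch_logits = np.zeros((1, 600, 32), dtype=np.float32)


def frames_for_decode(logits):
    """Return a 2-D (time, vocab) array for a single utterance."""
    if logits.ndim == 3:
        if logits.shape[0] != 1:
            raise ValueError("expected a batch containing exactly one utterance")
        return logits[0]  # drop the leading batch axis
    return logits


single = frames_for_decode(batch_logits)
print(single.shape)  # (600, 32)
```

In the reproduction script this would correspond to passing `logits[0].numpy()` to `processor.decode` instead of `logits.numpy()`.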

gfigueroa commented 2 years ago

This worked, thanks!

gkucsko commented 2 years ago

@gfigueroa what hardware and OS / python version are you using? I'm reproducing this on mac and python 3.9 but not on ubuntu and py3.9 or mac and py3.7...

gkucsko commented 2 years ago

@patrickvonplaten this may be fixed by changing the Pool start method from 'spawn' to 'fork', i.e. having Wav2Vec2ProcessorWithLM.batch_decode create its pool with multiprocessing.get_context('fork').Pool(). See here for the different defaults: https://stackoverflow.com/questions/64095876/multiprocessing-fork-vs-spawn
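
A minimal sketch of why the start method matters, assuming state that is registered at runtime in the parent process. `MODEL_CONTAINER`, `lookup_and_apply` and `decode_batch` are hypothetical names, not part of pyctcdecode or Transformers. With 'fork', workers inherit the parent's memory, including the populated dict. With 'spawn', workers start a fresh interpreter and re-import the module, so the dict would be empty and the lookup would raise KeyError. macOS has used 'spawn' by default since Python 3.8, which matches the platforms where the error was reproduced. Note that 'fork' is not available on Windows.

```python
import multiprocessing as mp

# Hypothetical module-level registry, filled at runtime in the parent process.
MODEL_CONTAINER = {}


def lookup_and_apply(args):
    """Worker: resolve the model by key inside the worker process and apply it."""
    key, text = args
    return MODEL_CONTAINER[key](text)


def decode_batch(key, texts, start_method="fork"):
    """Map texts over a pool created from an explicit start-method context."""
    ctx = mp.get_context(start_method)
    with ctx.Pool(processes=2) as pool:
        return pool.map(lookup_and_apply, [(key, t) for t in texts])


if __name__ == "__main__":
    MODEL_CONTAINER[b"demo-key"] = str.upper
    print(decode_batch(b"demo-key", ["hello", "world"]))
```

Requesting the context explicitly keeps the behavior the same across platforms instead of depending on each platform's default.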

patrickvonplaten commented 2 years ago

Yeah that makes sense! Cool, I'll reply on the issue

patrickvonplaten commented 2 years ago

@gfigueroa - could you verify whether you still encounter the problem when using this branch of Transformers: https://github.com/huggingface/transformers/pull/15247?

gfigueroa commented 2 years ago

@gkucsko not sure if still useful, but I'm running this on python 3.8 on mac Big Sur (11.4). @patrickvonplaten I used your branch and batch_decode() works as specified in your blog post, thanks!

gkucsko commented 2 years ago

Docs, tests and tutorials will be updated with https://github.com/kensho-technologies/pyctcdecode/pull/51. Thanks for the help, closing.