kensho-technologies / pyctcdecode

A fast and lightweight python-based CTC beam search decoder for speech recognition.
Apache License 2.0

pyctcdecode fails when running with a spawn context (multiprocessing) #65

Closed elsheikh21 closed 1 year ago

elsheikh21 commented 2 years ago

I opened a Transformers PR, but it failed after the multiprocessing context was changed from fork to spawn:

from multiprocessing import get_context
...
pool = get_context("spawn").Pool(num_processes)

in the following file src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py

My question is: how can I run pyctcdecode batch decoding on Windows? Is there any way I can help or contribute to fixing this issue?
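For context: Windows provides no fork start method, which is why the context was changed to spawn in the first place; requesting fork there fails with the same ValueError that shows up later in this thread. A quick standard-library check:

from multiprocessing import get_context

get_context("spawn")  # available on all platforms
get_context("fork")   # on Windows: ValueError: cannot find context for 'fork'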

lopez86 commented 2 years ago

Have you done any more investigation into this? I haven't tried to use pyctcdecode on Windows and I'm not aware of others who have, so it wouldn't be surprising if there are some bugs around multiprocessing on Windows. Can you make a minimal example using pyctcdecode where this fails, ideally one that doesn't need other dependencies like transformers? Or is this specific to some combination of transformers and pyctcdecode?
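For reference, a transformers-free repro might look roughly like this (an untested sketch: the vocabulary and random logits are placeholders, and it assumes a plain build_ctcdecoder decoder with no kenlm model):

import numpy as np
from multiprocessing import get_context
from pyctcdecode import build_ctcdecoder

if __name__ == "__main__":
    labels = ["", "a", "b", "c", " "]  # placeholder vocabulary; "" is the CTC blank
    decoder = build_ctcdecoder(labels)  # no language model, to keep the repro minimal
    # random stand-in logits: 4 utterances, 50 frames, len(labels) classes each
    logits_list = [np.random.rand(50, len(labels)).astype(np.float32) for _ in range(4)]
    with get_context("spawn").Pool(2) as pool:
        transcriptions = decoder.decode_batch(pool, logits_list)
    print(transcriptions)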

elsheikh21 commented 2 years ago

Hello, I tried to use this package with the transformers package and a Wav2Vec2-with-LM model to transcribe a WAV audio file. Please find the code below:

from os import getcwd
from os.path import join as path_join
from pprint import pprint

import soundfile as sf
import torch
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM

model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
model = Wav2Vec2ForCTC.from_pretrained(model_name)
processor_path = path_join(getcwd(), "stt_assets", "stt_model")
processor = Wav2Vec2ProcessorWithLM.from_pretrained(processor_path)

dataset = load_dataset("timit_asr", split="test").shuffle().shuffle().select(range(100))
char_translations = str.maketrans({"-": " ", ",": "", ".": "", "?": ""})

def prepare_example(example):
    example["speech"], _ = sf.read(example["file"])
    example["text"] = example["text"].translate(char_translations)
    example["text"] = " ".join(example["text"].split())  # clean up whitespace
    example["text"] = example["text"].lower()
    return example

dataset = dataset.map(prepare_example, remove_columns=["file"])

pprint(dataset)
# batch the preprocessed audio arrays from the dataset
features = processor(dataset["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(**features).logits

# logits shape is torch.Size([100, 304, 33])
transcription = processor.batch_decode(logits)
# EXCEPTION IS RAISED in `processor.batch_decode()` ValueError: cannot find context for 'fork'
print(transcription)

lopez86 commented 2 years ago

I was looking into this a bit. What's happening is that because the language model is stored as a class variable, fork works without having to reload the language model, but spawn ends up creating processes where the language model is missing. If the language model were instead made an instance variable, both fork and spawn would work, but each process would need to reload the language model, which would likely wipe out any performance gains from using multiprocessing. I'm not sure what a good solution to this would be.
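To illustrate the general mechanism (a generic sketch, not pyctcdecode's actual code): state stored at class level survives fork but not spawn, because spawned workers re-import the module fresh and never see what the parent loaded:

import multiprocessing as mp

class Decoder:
    _models = {}  # class-level container, analogous to a shared LM registry

    @classmethod
    def load(cls, key):
        cls._models[key] = "big-language-model"  # stand-in for an expensive load

    def decode(self, key):
        return self._models.get(key, "MISSING")

def worker(key):
    return Decoder().decode(key)

if __name__ == "__main__":
    Decoder.load("lm")
    for method in ("fork", "spawn"):
        try:
            with mp.get_context(method).Pool(2) as pool:
                # fork inherits _models; spawn re-imports the module, so it's empty
                print(method, pool.map(worker, ["lm", "lm"]))
        except ValueError as err:  # "fork" does not exist on Windows
            print(method, "unavailable:", err)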

lopez86 commented 2 years ago

I've added some code that allows either skipping multiprocessing entirely, or bypassing it when a spawn-context pool is detected. It's not a true solution, but it should unblock things at least. If anyone has a way to actually use a pool with a spawn context without hurting performance, please discuss or consider making a PR.
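Roughly, the shape of the fallback (a sketch under assumptions, not the exact merged code; Pool._ctx is a private CPython attribute used here only for illustration, and decoder stands for something like a BeamSearchDecoderCTC):

def decode_batch_if_fork(decoder, logits_list, pool=None):
    # only hand work to the pool when it was created with a fork context
    if pool is not None and pool._ctx.get_start_method() == "fork":
        return decoder.decode_batch(pool, logits_list)
    # no pool, or a spawn/forkserver pool: decode serially in this process
    return [decoder.decode(logits) for logits in logits_list]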

elsheikh21 commented 2 years ago

@lopez86 excuse me for taking so long to respond; I had some personal issues to sort out. Thanks a lot for your help. I am discussing this with Patrick von Platen in the Hugging Face issue that was created earlier. Thanks again for your help and time.