jianfch / stable-ts

Transcription, forced alignment, and audio indexing with OpenAI's Whisper

Whisper WebUI VAD segmenting #287

Open drohack opened 6 months ago

drohack commented 6 months ago

Would it be possible to add a similar option to stable-ts that Whisper-WebUI (https://gitlab.com/aadnk/whisper-webui) uses to split up the audio into segments of perceived audio using VAD?

The main benefit is that for longer audio files you can leave condition_on_previous_text set to True with minimal looping errors (where the transcription gets stuck in an infinite loop repeating the same output). Because each VAD segment is much shorter, the previous text is only conditioned on within that short audio section.

I've been able to download the Whisper-WebUI code and replace their faster_whisper transcription model with stable-ts, so I know it works. I wouldn't expect it to change transcription behavior, since it doesn't actually touch the Whisper code; it just handles splitting up the audio and merging the transcript back together at the end.

I know this is a pretty big feature request, but I do think it would be beneficial.
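
To illustrate the general shape of it (just a rough sketch with illustrative names, not WebUI's actual code; it assumes Silero VAD's get_speech_timestamps and plain stable-ts):

    # Sketch of the idea: detect speech regions with Silero VAD, transcribe each
    # region on its own, then shift the timestamps back by each region's offset
    # and merge. WebUI additionally merges/pads nearby regions (vadMergeWindow,
    # vadPadding), which is skipped here.
    import torch
    import stable_whisper

    SAMPLE_RATE = 16000  # Whisper operates on 16 kHz audio

    vad_model, vad_utils = torch.hub.load('snakers4/silero-vad', 'silero_vad')
    get_speech_timestamps, _, read_audio, *_ = vad_utils

    audio = read_audio('input.mp3', sampling_rate=SAMPLE_RATE)  # placeholder path
    speech_chunks = get_speech_timestamps(audio, vad_model, sampling_rate=SAMPLE_RATE)

    model = stable_whisper.load_model('large-v3')

    merged_segments = []
    for chunk in speech_chunks:
        start, end = chunk['start'], chunk['end']  # offsets in samples
        # condition_on_previous_text only carries within this short chunk,
        # which is what keeps the looping errors contained
        result = model.transcribe(audio[start:end].numpy(),
                                  condition_on_previous_text=True)
        offset = start / SAMPLE_RATE
        for seg in result.segments:
            merged_segments.append({'start': seg.start + offset,
                                    'end': seg.end + offset,
                                    'text': seg.text})

    full_text = ' '.join(s['text'] for s in merged_segments)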

dgoryeo commented 5 months ago

Hi @drohack, just saw this. Would you be able to share your modifications to replace their faster_whisper transcription model with stable-ts? Thanks.

drohack commented 5 months ago

I kind of hacked it together by downloading the Whisper-WebUI source code and importing it into my project, then updating/replacing the fasterWhisperContainer.py file to call stable-ts instead. (That was faster than reverse engineering how they segment/batch the audio file and pass it into faster_whisper. Though now my project is just a snapshot of WebUI...)

Here's my code calling WebUI. It looks pretty similar to how you'd call it normally for faster_whisper. The one main difference is that you have to pass input_sr (the audio file's sample rate) for stable-ts to be able to transcribe the audio. This is because WebUI converts the audio file into raw audio tensors, and stable-ts needs the sample rate to handle them:

    data_models = [
        {"name": "medium", "url": "medium"},
        {"name": "large-v3", "url": "large-v3"}
    ]
    models = [ModelConfig(**x) for x in data_models]

    # Load the stable-ts-backed container and make sure the model is downloaded
    model = FasterWhisperContainer(model_name='large-v3', device='cuda', compute_type='auto', models=models)
    model.ensure_downloaded()

    # WebUI's VAD settings: merge nearby speech regions and cap each merged chunk at 180 s
    vad_options = VadOptions(vad='silero-vad', vadMergeWindow=5, vadMaxMergeSize=180,
                             vadPadding=1, vadPromptWindow=1)

    wwebui = WhisperTranscriber()
    result = wwebui.transcribe_file(model=model, audio_path=audio_file_path, language='Japanese', task='transcribe',
                                    vadOptions=vad_options, input_sr=int(input_sr),
                                    beam_size=5,
                                    word_timestamps=True, condition_on_previous_text=True,
                                    no_speech_threshold=0.35,
                                    )
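
For context on input_sr: the point is just to get the incoming audio tensor down to the 16 kHz that Whisper expects before transcribing. Conceptually it's something like this (an illustrative sketch, not the actual container code):

    # Illustrative sketch only: resample WebUI's audio tensor from the source
    # sample rate to the 16 kHz that Whisper / faster-whisper expect.
    import torch
    import torchaudio.functional as AF

    WHISPER_SR = 16000

    def prepare_audio(audio: torch.Tensor, input_sr: int) -> torch.Tensor:
        if input_sr != WHISPER_SR:
            audio = AF.resample(audio, orig_freq=input_sr, new_freq=WHISPER_SR)
        return audio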

Here are the relevant lines I edited in WebUI/src/whisper/fasterWhisperContainer.py so that it calls stable-ts instead:

import stable_whisper
from stable_whisper import WhisperResult

Update the _create_model() to load the stable_whisper model:

    def _create_model(self):
        print("Loading faster whisper model " + self.model_name + " for device " + str(self.device))
        model_config = self._get_model_config()
        model_url = model_config.url

        if model_config.type == "whisper":
            if model_url not in ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "large-v3"]:
                raise Exception(
                    "FasterWhisperContainer does not yet support Whisper models. Use ct2-transformers-converter to convert the model to a faster-whisper model.")
            if model_url == "large":
                # large is an alias for large-v3
                model_url = "large-v3"

        device = self.device

        if (device is None):
            device = "auto"

        model = stable_whisper.load_faster_whisper(model_url, device=device, compute_type=self.compute_type, num_workers=2)
        return model

Update invoke() to run the transcribe_stable() function and return the result in the same format that WebUI is expecting. (I did have to remove some of the information being passed back, like the progress_listener updates and duration, but that's just extra information that isn't needed.)

        # See if suppress_tokens is a string - if so, convert it to a list of ints
        decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)

        initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
                           if self.prompt_strategy else prompt

        result: WhisperResult = model.transcribe_stable(
            audio,
            language=language_code if language_code else detected_language,
            task=self.task,
            initial_prompt=initial_prompt,
            **decodeOptions
        )

        segments = []

        for segment in result.segments:
            segments.append(segment)

            #if progress_listener is not None:
            #    progress_listener.on_progress(segment.end, info.duration)
            if verbose:
                print("[{}->{}] {}".format(format_timestamp(segment.start, True), format_timestamp(segment.end, True),
                                          segment.text))

        text = " ".join([segment.text for segment in segments])

        # Convert the segments to a format that is easier to serialize
        whisper_segments = [{
            "text": segment.text,
            "start": segment.start,
            "end": segment.end,

            # Extra fields added by faster-whisper
            "words": [{
                "start": word.start,
                "end": word.end,
                "word": word.word,
                "probability": word.probability
            } for word in (segment.words if segment.words is not None else []) ]
        } for segment in segments]

        result = {
            "segments": whisper_segments,
            "text": text,
            "language": result.language if result else None,

            # Extra fields added by faster-whisper
            "language_probability": None,
            "duration": None
        }

        # If we have a prompt strategy, we need to increment the current prompt
        if self.prompt_strategy:
            self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)

        if progress_listener is not None:
            progress_listener.on_finished()
        return result

It really was just a few lines of code changed.

I do have my full project posted here: https://github.com/drohack/AutoSubVideos

dgoryeo commented 5 months ago

Thanks so much @drohack !