shashikg / WhisperS2T

An Optimized Speech-to-Text Pipeline for the Whisper Model Supporting Multiple Inference Engines

[`large-v3`] Error during transcription: Invalid input features shape: expected an input with shape (3, 80, 3000), but got an input with shape (3, 128, 3000) instead #51

Open twardoch opened 4 months ago

twardoch commented 4 months ago

Cell 1

!apt install ffmpeg
!pip install whisper-s2t yt-dlp gradio pydantic ffmpeg-python

Cell 2

import logging
from pathlib import Path
import whisper_s2t

from google.colab import drive

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
class Config:
    model_identifier = "large-v3" # This causes a problem
    backend = "CTranslate2"
    output_format = "vtt"
    max_workers = 16
    path_root = "/content/drive"
    cwd = Path(path_root, "MyDrive/Colab Notebooks/YouTube Videos")

drive.mount(Config.path_root)

Cell 3

whisper_s2t_model = whisper_s2t.load_model(
    model_identifier=Config.model_identifier,
    backend=Config.backend,
    asr_options={"word_timestamps": True},
    # n_mels=128 # This doesn't matter
)

Cell 4

import asyncio
import os
import shutil
from concurrent.futures import ThreadPoolExecutor

import ffmpeg
import yt_dlp
from pydantic import BaseModel

# Pydantic model for VideoToTranscribe
class VideoToTranscribe(BaseModel):
    video_path: Path
    audio_path: Path
    metadata: dict | None = None
    lang_code: str = "en"
    initial_prompt: str | None = None
    vtt_path: Path

class VideoTranscriptor:
    def __init__(self, cwd: Path, whisper_s2t_model):
        self.cwd = cwd
        self.input_videos_dir = cwd / "input_videos"
        self.input_audios_dir = cwd / "input_audios"
        self.transcribed_dir = cwd / "transcribed"

        # Create directories if they don't exist
        self.input_videos_dir.mkdir(parents=True, exist_ok=True)
        self.input_audios_dir.mkdir(parents=True, exist_ok=True)
        self.transcribed_dir.mkdir(parents=True, exist_ok=True)
        self.whisper_s2t_model = whisper_s2t_model

    async def download_youtube_videos(self, url: str):
        ydl_opts = {
            "format": "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
            "outtmpl": str(self.input_videos_dir / "%(id)s.%(ext)s"),
        }

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])

    async def transcribe_audio(
        self, audio_paths: list[Path], lang_codes: list[str], tasks: list[str]
    ):
        vtt_paths = [
            self.transcribed_dir / f"{audio_path.stem}.{Config.output_format}"
            for audio_path in audio_paths
        ]

        out = self.whisper_s2t_model.transcribe_with_vad(
            [str(audio_path) for audio_path in audio_paths],
            lang_codes=lang_codes,
            tasks=tasks,
            initial_prompts=[None] * len(audio_paths),
            batch_size=Config.max_workers,
        )

        whisper_s2t.write_outputs(
            out,
            format=Config.output_format,
            op_files=[str(vtt_path) for vtt_path in vtt_paths],
        )

        return vtt_paths

    async def process_videos(self, lang_code: str, output_lang_code: str):
        video_paths = list(self.input_videos_dir.glob("*.mp4"))

        def extract_audio(video_path: Path):
            audio_path = self.input_audios_dir / f"{video_path.stem}.wav"

            try:
                (
                    ffmpeg.input(str(video_path))
                    .output(str(audio_path), acodec="pcm_s16le", ar=16000, ac=1)
                    .overwrite_output()
                    .run(capture_stdout=True, capture_stderr=True)
                )
            except ffmpeg.Error as e:
                logger.error(
                    f"Error while extracting audio from {video_path}: {e.stderr.decode()}"
                )
                raise e

            return audio_path

        with ThreadPoolExecutor(max_workers=Config.max_workers) as executor:
            audio_extraction_tasks = [
                asyncio.get_running_loop().run_in_executor(
                    executor, extract_audio, video_path
                )
                for video_path in video_paths
            ]
            audio_paths = await asyncio.gather(*audio_extraction_tasks)

            task = "transcribe" if lang_code == output_lang_code else "translate"
            tasks = [task] * len(audio_paths)
            lang_codes = [output_lang_code] * len(audio_paths)

            vtt_paths = await self.transcribe_audio(audio_paths, lang_codes, tasks)

        videos_to_transcribe = [
            VideoToTranscribe(
                video_path=video_path, audio_path=audio_path, vtt_path=vtt_path
            )
            for video_path, audio_path, vtt_path in zip(
                video_paths, audio_paths, vtt_paths
            )
        ]

        return videos_to_transcribe

    async def cleanup(self, videos_to_transcribe: list[VideoToTranscribe]):
        for video in videos_to_transcribe:
            if video.vtt_path.exists():
                if video.video_path.exists():
                    shutil.move(str(video.video_path), str(self.transcribed_dir))
                    os.remove(str(video.audio_path))
                else:
                    shutil.move(str(video.audio_path), str(self.transcribed_dir))

    async def transcribe(self, youtube_url: str, lang_code: str, output_lang_code: str):
        if youtube_url:
            logger.info(f"Downloading YouTube video(s) from: {youtube_url}")
            await self.download_youtube_videos(youtube_url)

        logger.info("Processing videos...")
        videos_to_transcribe = await self.process_videos(lang_code, output_lang_code)

        logger.info("Cleaning up temporary files...")
        await self.cleanup(videos_to_transcribe)

        return f"Transcription completed. Files saved in {self.transcribed_dir}"

Cell 5


import gradio as gr

# Gradio UI

def launch_ui():
    cwd = Path(Config.cwd)
    transcriptor = VideoTranscriptor(cwd, whisper_s2t_model)

    async def transcribe_wrapper(
        youtube_url: str, lang_code: str, output_lang_code: str
    ):
        try:
            result = await transcriptor.transcribe(
                youtube_url, lang_code, output_lang_code
            )
            return result
        except Exception as e:
            logger.error(f"Error during transcription: {str(e)}")
            return f"An error occurred during transcription: {str(e)}"

    input_components = [
        gr.Textbox(
            label="YouTube URL (optional)",
            placeholder="Enter a YouTube video or playlist URL",
        ),
        gr.Textbox(label="Source Language Code", value="en"),
        gr.Textbox(label="Target Language Code", value="en"),
    ]

    iface = gr.Interface(
        fn=transcribe_wrapper,
        inputs=input_components,
        outputs="text",
        title="YouTube Video Transcriptor",
        description="Transcribe YouTube videos or local video files using Whisper",
        allow_flagging="never",
    )

    iface.launch(debug=False, share=True)

if __name__ == "__main__":
    launch_ui()

When I try to run this code with the large-v3 model identifier, I keep getting:

ERROR:__main__:Error during transcription: Invalid input features shape: expected an input with shape (3, 80, 3000), but got an input with shape (3, 128, 3000) instead

With large-v2, it works fine.

AmgadHasan commented 4 months ago

Try uncommenting the `n_mels` line:

whisper_s2t_model = whisper_s2t.load_model(
    model_identifier=Config.model_identifier,
    backend=Config.backend,
    asr_options={"word_timestamps": True},
    # n_mels=128 # This doesn't matter
)
twardoch commented 4 months ago

By "this doesn't work" I meant: it fails if the parameter is commented or uncommented.

shashikg commented 4 months ago

@twardoch this is a bug in the aligner model. By default, the tiny model is used for alignment, which expects n_mels of size 80, while large-v3 expects n_mels of size 128. Since the same preprocessor is shared between the two models, you are getting this issue.

I will fix this in the next release.
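
To illustrate where the two shapes in the error message come from, here is a rough sketch using torchaudio rather than the WhisperS2T preprocessor itself (the STFT parameters are assumptions matching Whisper's usual 16 kHz, 25 ms window, 10 ms hop front end, and the log/normalization steps are omitted):

import torch
import torchaudio

# Batch of three 30-second clips at 16 kHz, mirroring the batch in the error message.
audio = torch.randn(3, 16000 * 30)

def mel_features(n_mels: int) -> torch.Tensor:
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_fft=400, hop_length=160, n_mels=n_mels
    )
    # Drop the final frame, as Whisper does, to get 3000 frames per 30 s clip.
    return mel(audio)[..., :-1]

print(mel_features(80).shape)   # torch.Size([3, 80, 3000])  -> what the tiny aligner expects
print(mel_features(128).shape)  # torch.Size([3, 128, 3000]) -> what large-v3 produces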

Meanwhile, to use large-v3, disable word timestamps, which should fix your issue:

asr_options={"word_timestamps": False},
twardoch commented 4 months ago

Thanks! I do want them wordstamps though ;)

aleksandr-smechov commented 4 months ago

@twardoch You can add a separate preprocessor with a fixed number of n_mels, as shown in this commit.
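
Roughly, the idea is to keep two feature extractors side by side: one sized for the main model and one with the 80 mel bins the aligner expects. A hypothetical sketch of that idea (the helper and attribute names below are illustrative assumptions, not the actual WhisperS2T API or the code in that commit):

import torch
import torchaudio

def build_mel_extractor(n_mels: int) -> torchaudio.transforms.MelSpectrogram:
    # Hypothetical helper: one extractor per mel size, again assuming Whisper's
    # 16 kHz / 25 ms / 10 ms front end.
    return torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_fft=400, hop_length=160, n_mels=n_mels
    )

main_preprocessor = build_mel_extractor(128)    # features for the large-v3 encoder
aligner_preprocessor = build_mel_extractor(80)  # features for the tiny alignment model

audio = torch.randn(1, 16000 * 30)
encoder_feats = main_preprocessor(audio)[..., :-1]     # shape (1, 128, 3000)
aligner_feats = aligner_preprocessor(audio)[..., :-1]  # shape (1, 80, 3000)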