SYSTRAN / faster-whisper

Faster Whisper transcription with CTranslate2
MIT License

Feeding raw audio data to faster whisper over websockets #1077

Open FredTheNoob opened 1 month ago

FredTheNoob commented 1 month ago

I have the following frontend code, which captures microphone audio in the browser and sends it over a WebSocket:

const webSocket = new WebSocket('ws://127.0.0.1:3000');

webSocket.onmessage = event => {
    console.log('Message from server:', event.data);
}

webSocket.onopen = () => {
    console.log('Connected to server');
};

webSocket.onclose = (event) => {
    console.log('Disconnected from server: ', event.code, event.reason);
};

webSocket.onerror = error => {
    console.error('Error:', error);
}

const constraints = { audio: true };
let recorder;

function start() {
    navigator.mediaDevices
        .getUserMedia(constraints)
        .then(mediaStream => {

        // use MediaStream Recording API
        recorder = new MediaRecorder(mediaStream);

        // fires every two seconds (see the timeslice below) and passes a BlobEvent
        recorder.ondataavailable = event => {
            // get the Blob from the event
            const blob = event.data;

            // and send that blob to the server...
            webSocket.send(blob);
        };

        // make the dataavailable event fire every two seconds
        recorder.start(2000);
    });
}

function stop() {
    recorder.stop();
    webSocket.close(1000, "Finished sending audio");
}

It uses the MediaRecorder API to send an audio chunk every 2 seconds. This is received on the backend like this:

main.py:

import asyncio
import websockets

from ASR.ASR import ASR

_ASR = ASR("tiny", "auto", "int8")

async def handler(websocket):
    while True:
        try:
            # Receiving binary data directly from the client
            data = await websocket.recv()
            # Handle the audio data with Whisper
            _ASR.process_audio(data)
            # Optionally, send an acknowledgment back to the client
            await websocket.send("Chunk received")
        except websockets.ConnectionClosed:
            print("Connection closed")
            break

# Start WebSocket server
start_server = websockets.serve(handler, "127.0.0.1", 3000)

asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
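
(Aside: asyncio.get_event_loop() is deprecated in recent Python versions; on websockets 11+ the same server can be started with the pattern below. This is a sketch that assumes the handler above is unchanged.)

import asyncio
import websockets

async def main():
    # Same host/port as above; handler is unchanged
    async with websockets.serve(handler, "127.0.0.1", 3000):
        await asyncio.Future()  # run until cancelled

asyncio.run(main())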

ASR.py:

from io import BytesIO
import re
from typing import List
from ASR.LocalAgreement import LocalAgreement
from faster_whisper import WhisperModel 
import soundfile as sf

class ASR:
    audio_buffer: BytesIO = BytesIO()
    local_agreement = LocalAgreement()
    context: str = ""
    confirmed_sentences: List[str] = []

    def __init__(self, model_size: str, device="auto", compute_type="int8"):
        self.whisper_model = WhisperModel(model_size, device=device, compute_type=compute_type)

    def transcribe(self, audio_buffer: BytesIO, context: str):
        transcribed_text = ""
        segments, info = self.whisper_model.transcribe(audio_buffer)

        for segment in segments:
            transcribed_text += " " + segment.text

        return transcribed_text

    def process_audio(self, audio_chunk) -> str:
        # Append new audio data to the main buffer
        self.audio_buffer.write(audio_chunk)
        self.audio_buffer.seek(0)  # Reset buffer's position to the beginning

        transcribed_text = self.transcribe(self.audio_buffer, self.context)
        print("transcribed_text: " + transcribed_text)
        confirmed_text = self.local_agreement.confirm_tokens(transcribed_text)
        print(confirmed_text)
        punctuation = r"[.!?]"  # Regular expression pattern for ., !, or ?
        # Detect punctuation
        print("check punctuation: ", re.search(punctuation,confirmed_text))
        if re.search(punctuation,confirmed_text):
            split_sentence = re.split(f"({punctuation})", confirmed_text)

            # Join the punctuation back to the respective parts of the sentence
            sentence = [split_sentence[i] + split_sentence[i+1] for i in range(0, len(split_sentence)-1, 2)]

            print("sentence", sentence)
            self.confirmed_sentences.append(sentence[-1])
            self.context = " ".join(self.confirmed_sentences)
            print("context added: " + self.context)

            # Clear the main audio buffer only after processing is complete
            self.audio_buffer = BytesIO()

        return confirmed_text

The issue happens when I try to clear the audio buffer. My idea is to clear the buffer every time I detect punctuation, meaning a sentence has ended. However, clearing the buffer throws the following error (a likely cause and a workaround sketch follow the traceback):

connection handler failed
Traceback (most recent call last):
  File "/Users/frederik/Uni/P7/P7Project/backend/.venv/lib/python3.12/site-packages/websockets/legacy/server.py", line 245, in handler
    await self.ws_handler(self)
  File "/Users/frederik/Uni/P7/P7Project/backend/./src/__main__.py", line 15, in handler
    _ASR.process_audio(data)
  File "/Users/frederik/Uni/P7/P7Project/backend/./src/ASR/ASR.py", line 30, in process_audio
    transcribed_text = self.transcribe(self.audio_buffer, self.context)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/frederik/Uni/P7/P7Project/backend/./src/ASR/ASR.py", line 18, in transcribe
    segments, info = self.whisper_model.transcribe(audio_buffer)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/frederik/Uni/P7/P7Project/backend/.venv/lib/python3.12/site-packages/faster_whisper/transcribe.py", line 319, in transcribe
    audio = decode_audio(audio, sampling_rate=sampling_rate)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/frederik/Uni/P7/P7Project/backend/.venv/lib/python3.12/site-packages/faster_whisper/audio.py", line 46, in decode_audio
    with av.open(input_file, mode="r", metadata_errors="ignore") as container:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "av/container/core.pyx", line 420, in av.container.core.open
  File "av/container/core.pyx", line 266, in av.container.core.Container.__cinit__
  File "av/container/core.pyx", line 286, in av.container.core.Container.err_check
  File "av/error.pyx", line 326, in av.error.err_check
av.error.InvalidDataError: [Errno 1094995529] Invalid data found when processing input: '<none>'
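
For context on the traceback: faster-whisper passes the buffer to PyAV, which expects a parsable container. MediaRecorder produces one continuous WebM/Opus stream, and only the first blob carries the container header; every later dataavailable blob is a headerless fragment of the same stream. Concatenating all blobs yields a decodable file, which is why transcription works until the buffer is cleared; after clearing, the buffer starts mid-stream and av.open fails with InvalidDataError. One common workaround is to keep the first blob and prepend it before each decode. A minimal sketch (the class name is illustrative; it assumes the browser's default audio/webm;codecs=opus output and that blobs are cut at cluster boundaries, which Chrome generally does):

from io import BytesIO

class WebmChunkBuffer:
    """Keep the first MediaRecorder blob (which holds the WebM header)
    so later fragments can be rebuilt into a parsable stream."""

    def __init__(self):
        self.header = b""  # initialization segment from the first blob
        self.body = b""    # fragments received since the last flush

    def append(self, chunk: bytes):
        if not self.header:
            self.header = chunk  # first blob: header + first cluster
        else:
            self.body += chunk

    def as_file(self) -> BytesIO:
        # header + fragments form a stream PyAV can open
        return BytesIO(self.header + self.body)

    def flush(self):
        self.body = b""  # drop transcribed audio, keep the header

process_audio would then call self.transcribe(buffer.as_file(), self.context) followed by buffer.flush() instead of replacing the BytesIO object. Because the header blob also contains the first audio cluster, a short prefix of old audio is re-fed on every decode; sending raw PCM from the browser (as in the answer below) avoids the container problem entirely.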
bakazhou commented 4 weeks ago

Hi Fred,

I was facing the same issue. Below is my solution (a FastAPI WebSocket endpoint):

import io

import librosa
import numpy as np
import soundfile
from fastapi import FastAPI, WebSocket

app = FastAPI()

# SAMPLING_RATE (e.g. 16000) and recognize_service (a wrapper around
# WhisperModel.transcribe) are defined elsewhere in the service.

@app.websocket("/transcribe/streaming")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    buffer = b""
    while True:
        out = []
        raw_bytes = await websocket.receive_bytes()
        if not raw_bytes:
            break
        buffer += raw_bytes
        if buffer != b"":
            sf_buffer = soundfile.SoundFile(io.BytesIO(buffer), channels=1, endian="LITTLE", samplerate=SAMPLING_RATE,
                                            subtype="PCM_16", format="RAW")
            audio, _ = librosa.load(sf_buffer, sr=SAMPLING_RATE, dtype=np.float32)
            out.append(audio)
            buffer = b""
        if out:
            audio_data = np.concatenate(out)
            audio_buffer = np.array([], dtype=np.float32)
            audio_buffer = np.append(audio_buffer, audio_data)
            try:
                segments, info = recognize_service.recognize(audio=audio_buffer, beam_size=5, language="en")
                result = {
                    "language": info.language,
                    "language_probability": info.language_probability,
                    "segments": [
                        {
                            "start": segment.start,
                            "end": segment.end,
                            "text": segment.text
                        } for segment in segments
                    ],
                }
                await websocket.send_json(data=result)
            except Exception as e:
                print(e)
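
Worth noting: this endpoint reads the bytes as headerless raw audio (format="RAW", subtype="PCM_16", mono), so it only works if the client sends raw 16-bit PCM samples, e.g. captured with an AudioWorklet and resampled to SAMPLING_RATE, rather than MediaRecorder blobs; that is what sidesteps the container problem above. Under the same assumption the soundfile/librosa round-trip can be dropped, since WhisperModel.transcribe also accepts a float32 NumPy array sampled at 16 kHz. A minimal sketch (transcribe_pcm16 is an illustrative helper, not part of the code above):

import numpy as np
from faster_whisper import WhisperModel

SAMPLING_RATE = 16000  # client must send 16 kHz mono little-endian PCM_16
model = WhisperModel("tiny", device="auto", compute_type="int8")

def transcribe_pcm16(raw_bytes: bytes) -> str:
    # Reinterpret the bytes as int16 samples and scale to [-1.0, 1.0]
    audio = np.frombuffer(raw_bytes, dtype="<i2").astype(np.float32) / 32768.0
    segments, _info = model.transcribe(audio, beam_size=5, language="en")
    return " ".join(segment.text.strip() for segment in segments)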