nkeenan38 / voice_activity_detector

A Voice Activity Detector rust library using the Silero VAD model.
https://crates.io/crates/voice_activity_detector
MIT License

Should the VoiceActivityDetector be created on every iteration of a stream in RT contexts? #23

Closed Tails closed 6 months ago

Tails commented 6 months ago

I am using the Rust SDK of LiveKit, which is a WebRTC server. I have a handler like this:

tokio::spawn(async move {
    // Receive the audio frames in a new task
    while let Some(audio_frame) = audio_stream.next().await {
        log::info!("received audio frame - {audio_frame:#?}");

        pub fn chunk_size(af: &AudioFrame) -> usize {
            (af.samples_per_channel
                * af.num_channels
                * std::mem::size_of::<i16>() as u32)
                as usize
        }

        assert_eq!(audio_frame.num_channels, 1, "only mono audio supported");

        // https://crates.io/crates/voice_activity_detector

        // Create a new VoiceActivityDetector with the sample rate of the audio frame
        // TODO: is it expensive to create a new VAD for every audio frame?
        let vad = VoiceActivityDetector::builder()
            .sample_rate(audio_frame.sample_rate)
            .chunk_size(chunk_size(&audio_frame))
            .build()?;

        // This will label any audio chunks with a probability greater than 75% as speech,
        // and label the 3 additional chunks before and after these chunks as speech.
        let labels = audio_frame.data.iter().label(vad, 0.75, 3);
        for label in labels {
            match label {
                LabeledAudio::Speech(_) => println!("speech detected!"),
                LabeledAudio::NonSpeech(_) => println!("non-speech detected!"),
            }
        }
    }
});

Is it intended that the VAD struct is created on every chunk received? Or should this not be a performance issue? I imagine the prediction not working well if the data is lost on every iteration.

I have tried reusing the struct by placing it above the while {}, but the .iter().label(vad) call takes ownership of the struct and consumes it.

nkeenan38 commented 6 months ago

Hi. You are right that the prediction requires the "historic" data. So we need to use the same VAD struct for a single stream of audio.

I think there is just some confusion here with the extension traits for iterators and streams. This was built to work on a long Iterator/Stream of individual samples, in this case i16. There is a bit of an issue using this with the livekit SDK because each item in the stream is a slice of samples instead of a single one. So we need to flatten those before using the label extension method.

I put one solution together here (please excuse my usage of unwrap). This assumes the sample rate is known ahead of time. Also, the chunk size does not have to match the number of samples provided in each slice by livekit. The VAD will buffer as needed, so you can usually just choose one based on the number of samples/duration of audio you'd like to predict each interval.

use futures::StreamExt;
use livekit::webrtc::{audio_frame::AudioFrame, audio_stream::native::NativeAudioStream};
use voice_activity_detector::{LabeledAudio, StreamExt as _, VoiceActivityDetector};

fn audio_stream() -> NativeAudioStream {
    todo!()
}

// We need the type hints from this function to work around some strange type inference errors
// when this behavior is inlined.
// ```implementation of `std::ops::FnOnce` is not general enough```
// See https://github.com/rust-lang/rust/issues/89976
fn samples(frame: AudioFrame<'_>) -> Vec<i16> {
    frame.data.into_owned()
}

#[tokio::main]
async fn main() {
    tokio::spawn(async move {
        let audio_stream = audio_stream();

        let vad = VoiceActivityDetector::builder()
            .sample_rate(8000)
            .chunk_size(512usize)
            .build()
            .unwrap();

        let mut labels = audio_stream
            .map(samples) // Map the audio frame to an iterator of owned samples
            .map(futures::stream::iter) // Convert the iterator into a stream
            .flatten() // Flatten the stream of streams of samples
            .label(vad, 0.75, 3); // Label the stream of audio samples

        while let Some(label) = labels.next().await {
            match label {
                LabeledAudio::Speech(_) => println!("speech detected!"),
                LabeledAudio::NonSpeech(_) => println!("non-speech detected!"),
            };
        }
    });
}
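The point about chunk size being independent of the frame size can be illustrated without the crate. The sketch below is not the library's internal buffering; `rechunk` is a hypothetical helper that flattens frames of arbitrary length into one run of samples and regroups them into fixed-size chunks, using only the standard library. The 160-sample frame length is an assumption (10 ms at 16 kHz).

```rust
/// Hypothetical helper (not part of voice_activity_detector): regroups frames of
/// arbitrary length into fixed-size chunks. Samples left over at the end (fewer
/// than `chunk_size`) are returned separately.
fn rechunk(frames: Vec<Vec<i16>>, chunk_size: usize) -> (Vec<Vec<i16>>, Vec<i16>) {
    let mut chunks = Vec::new();
    let mut buffer: Vec<i16> = Vec::with_capacity(chunk_size);

    // Flatten the frames into one continuous run of samples.
    for sample in frames.into_iter().flatten() {
        buffer.push(sample);
        if buffer.len() == chunk_size {
            // Hand the full chunk out and leave a fresh buffer behind.
            chunks.push(std::mem::replace(&mut buffer, Vec::with_capacity(chunk_size)));
        }
    }
    (chunks, buffer)
}

fn main() {
    // Assumed input: five frames of 160 samples each, 800 samples in total.
    let frames: Vec<Vec<i16>> = (0..5).map(|_| vec![0i16; 160]).collect();
    let (chunks, remainder) = rechunk(frames, 512);
    println!(
        "{} full chunk(s), {} samples still buffered",
        chunks.len(),
        remainder.len()
    );
}
```

With these inputs, one 512-sample chunk is ready for prediction and 288 samples wait for the next frames, which is why the frame length and the chunk size do not need to line up.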

If the sample rate is dynamic, then you can do something like this to get the first frame and use that to build the VAD.

use futures::StreamExt;
use livekit::webrtc::{audio_frame::AudioFrame, audio_stream::native::NativeAudioStream};
use voice_activity_detector::{LabeledAudio, StreamExt as _, VoiceActivityDetector};

fn audio_stream() -> NativeAudioStream {
    todo!()
}

// We need the type hints from this function to work around some strange type inference errors
// when this behavior is inlined.
// ```implementation of `std::ops::FnOnce` is not general enough```
// See https://github.com/rust-lang/rust/issues/89976
fn samples(frame: AudioFrame<'_>) -> Vec<i16> {
    frame.data.into_owned()
}

#[tokio::main]
async fn main() {
    tokio::spawn(async move {
        let mut audio_stream = audio_stream();

        let first = audio_stream.next().await.unwrap();
        let vad = VoiceActivityDetector::builder()
            .sample_rate(first.sample_rate)
            .chunk_size(512usize)
            .build()
            .unwrap();

        let mut labels = audio_stream
            .map(samples) // Map the audio frame to an iterator of owned samples
            .map(futures::stream::iter) // Convert the iterator into a stream
            .flatten() // Flatten the stream of streams of samples
            .label(vad, 0.75, 3); // Label the stream of audio samples

        while let Some(label) = labels.next().await {
            match label {
                LabeledAudio::Speech(_) => println!("speech detected!"),
                LabeledAudio::NonSpeech(_) => println!("non-speech detected!"),
            };
        }
    });
}

I hope this helps. Let me know if I can clear up anything else.

Tails commented 6 months ago

Thank you for the quick response, and with such elegant code! This works nicely, great job man!

jason-shen commented 6 months ago

Very nice. Sorry, I do see it's a closed issue, but I have an add-on question: I see how it detects speech, but what is the best practice for gathering the audio stream/chunks for a speech-to-text use case?

thanks very much

nkeenan38 commented 6 months ago

So that would depend on the interface of the speech-to-text service. I put together a few options:

A. Let's assume you are working with a single long-lived websocket connection. You want to filter the audio to only include speech and send that to your transcriber. You can do this, which just removes any non-speech from the stream.

use std::future;

use futures::{Stream, StreamExt};
use voice_activity_detector::{LabeledAudio, StreamExt as _, VoiceActivityDetector};

fn audio_stream() -> impl Stream<Item = i16> {
    futures::stream::iter(vec![])
}

async fn transcribe(mut stream: impl Stream<Item = LabeledAudio<i16>> + Unpin) {
    while let Some(audio) = stream.next().await {
        match audio {
            LabeledAudio::Speech(_) => println!("speech"),
            LabeledAudio::NonSpeech(_) => println!("non-speech"),
        }
    }
}

#[tokio::main]
async fn main() {
    let vad = VoiceActivityDetector::builder()
        .sample_rate(8000)
        .chunk_size(512usize)
        .build()
        .unwrap();
    let filtered = audio_stream()
        .label(vad, 0.75, 3)
        .filter(|audio| future::ready(audio.is_speech()));
    transcribe(filtered).await;
}

B. Some transcribers require the silence/noise to either detect the passage of time or assist in inference. So if we are still working with a single stream, I'd probably just zero-out the nonspeech. This is effectively just de-noising the audio.

use futures::{Stream, StreamExt};
use voice_activity_detector::{LabeledAudio, StreamExt as _, VoiceActivityDetector};

fn audio_stream() -> impl Stream<Item = i16> {
    futures::stream::iter(vec![])
}

async fn transcribe(mut stream: impl Stream<Item = LabeledAudio<i16>> + Unpin) {
    while let Some(audio) = stream.next().await {
        match audio {
            LabeledAudio::Speech(_) => println!("speech"),
            LabeledAudio::NonSpeech(_) => println!("non-speech"),
        }
    }
}

#[tokio::main]
async fn main() {
    let vad = VoiceActivityDetector::builder()
        .sample_rate(8000)
        .chunk_size(512usize)
        .build()
        .unwrap();
    let filtered = audio_stream()
        .label(vad, 0.75, 3)
        .map(|audio| match audio {
            LabeledAudio::Speech(speech) => LabeledAudio::Speech(speech),
            LabeledAudio::NonSpeech(nonspeech) => LabeledAudio::NonSpeech(vec![0; nonspeech.len()]),
        });

    transcribe(filtered).await;
}

C. You want to transcribe distinct chunks of speech. In this case, we want to transform our stream of labeled audio into multiple streams, each containing only speech. Our strategy here is to take the stream by_ref so that we can consume it in chunks. We first skip all the items in the stream that are not speech and then take items from the stream while they are. Once the speech ends again, that sub-stream will be done. We repeat this until the stream is exhausted. I'll be adding this exact snippet as a test case because it's non-trivial, but likely a common use case.

use std::future;

use futures::{Stream, StreamExt};
use hound::WavSpec;
use voice_activity_detector::{LabeledAudio, StreamExt as _, VoiceActivityDetector};

/// Writes the stream to a file. Returns true if the stream is empty.
async fn write(
    mut stream: impl Stream<Item = LabeledAudio<i16>> + Unpin,
    iteration: usize,
    spec: WavSpec,
) -> Result<bool, Box<dyn std::error::Error>> {
    let filename = format!("tests/.outputs/chunk_stream.{iteration}.wav");
    let mut file = hound::WavWriter::create(filename, spec)?;

    let mut empty = true;
    while let Some(audio) = stream.next().await {
        empty = false;
        for sample in audio {
            file.write_sample(sample)?;
        }
    }

    file.finalize()?;
    Ok(empty)
}

#[tokio::test]
async fn chunk_stream() -> Result<(), Box<dyn std::error::Error>> {
    let mut reader = hound::WavReader::open("tests/samples/sample.wav")?;
    let spec = reader.spec();

    let vad = VoiceActivityDetector::builder()
        .sample_rate(8000)
        .chunk_size(512usize)
        .build()
        .unwrap();

    let chunks = reader.samples::<i16>().map_while(Result::ok);
    let mut labels = tokio_stream::iter(chunks).label(vad, 0.75, 3).fuse();

    for i in 0.. {
        let next = labels
            .by_ref()
            .skip_while(|audio| future::ready(!audio.is_speech()))
            .take_while(|audio| future::ready(audio.is_speech()));

        let empty = write(next, i, spec).await?;
        if empty {
            break;
        }
    }

    Ok(())
}

jason-shen commented 6 months ago

Thank you so much for the detailed reply man, awesome work you've done here. You need a donation link man.

jason-shen commented 6 months ago

having issues importing StreamExt from the crate though:

unresolved import voice_activity_detector::StreamExt
consider importing one of these items instead:
futures::StreamExt
futures_util::StreamExt
the item is gated behind the async feature

jason-shen commented 6 months ago

Sorry, my bad, I didn't read the docs properly. The async feature needs to be enabled.

jason-shen commented 6 months ago

here is what i came up with


const SAMPLE_RATE: u32 = 16000;
fn process_audio_frame(
    frame: AudioFrame<'_>,
    resampler: &mut audio_resampler::AudioResampler,
    target_num_channels: u32,
    target_sample_rate: u32,
) -> Vec<i16> {
    let data = resampler.remix_and_resample(
        &frame.data,
        frame.samples_per_channel,
        frame.num_channels,
        frame.sample_rate,
        target_num_channels,
        target_sample_rate,
    );

    // Convert the resampled data to i16
    data.iter()
        .map(|&x| (x as f32 * i16::MAX as f32) as i16)
        .collect()
}

pub async fn process_audio_track(
    audio_track: RemoteAudioTrack,
    whisper_transcriber: Arc<Mutex<WhisperModel>>,
) {
    let rtc_track = audio_track.rtc_track();
    let mut resampler = audio_resampler::AudioResampler::default();
    let target_sample_rate = 16000;
    let target_num_channels = 1;
    let audio_stream = NativeAudioStream::new(rtc_track);

    let vad = VoiceActivityDetector::builder()
        .sample_rate(target_sample_rate)
        .chunk_size(512usize)
        .build()
        .unwrap();

    let mut labels = audio_stream
        .map(Box::new(|frame: AudioFrame<'_>| {
            process_audio_frame(
                frame,
                &mut resampler,
                target_num_channels,
                target_sample_rate,
            )
        })
            as Box<dyn FnMut(AudioFrame<'_>) -> Vec<i16> + Send>)
        .map(futures::stream::iter)
        .flatten()
        .label(vad, 0.75, 3);

    while let Some(label) = labels.next().await {
        match label {
            LabeledAudio::Speech(audio_data) => {
                println!("Speech detected with {} samples", audio_data.len());

                tokio::spawn(process_audio_segment(
                    audio_data,
                    whisper_transcriber.clone(),
                ));
            }
            LabeledAudio::NonSpeech(non_speech_data) => {
                println!("Non-speech detected with {} samples", non_speech_data.len());
                // Debugging: Output non-speech data values
                for (i, sample) in non_speech_data.iter().enumerate().take(10) {
                    println!("Non-speech sample {}: {}", i, sample);
                }
            }
        }
    }
} 

Not sure if it's the right approach. And yeah, I am using LiveKit too. I always seem to hit non-speech even when there is speech. Any help would be awesome.

nkeenan38 commented 6 months ago

So everything is coming back as NonSpeech, is that right? I've found that when I've had that problem, it's always been either:

In this case, I think this might be your problem:

    // Convert the resampled data to i16
    data.iter()
        .map(|&x| (x as f32 * i16::MAX as f32) as i16)
        .collect()

When I paste this in my editor, I see that the data is already i16. So this might just be messing up the encoding.
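A minimal sketch of why that conversion is harmful when the input is already i16, assuming the resampler output is `&[i16]` as reported in this thread. Multiplying by `i16::MAX` only makes sense for normalized f32 samples in [-1.0, 1.0]. Rust's float-to-integer `as` casts saturate, so almost every i16 sample ends up pinned to the minimum or maximum value. The function names here are made up for illustration.

```rust
/// The conversion from the snippet above, applied to samples that are already i16.
fn scale_as_if_float(samples: &[i16]) -> Vec<i16> {
    samples
        .iter()
        // The product overflows the i16 range, and the `as i16` cast saturates.
        .map(|&x| (x as f32 * i16::MAX as f32) as i16)
        .collect()
}

/// What that scaling is meant for: normalized f32 samples in [-1.0, 1.0].
fn f32_to_i16(samples: &[f32]) -> Vec<i16> {
    samples
        .iter()
        .map(|&x| (x.clamp(-1.0, 1.0) * i16::MAX as f32) as i16)
        .collect()
}

fn main() {
    // Quiet i16 samples, similar in magnitude to the ones logged in this thread.
    let quiet = [-11i16, -14, 0, 3];
    println!("{:?}", scale_as_if_float(&quiet)); // [-32768, -32768, 0, 32767]

    // Normalized float samples convert as intended.
    println!("{:?}", f32_to_i16(&[-1.0, 0.0, 0.5, 1.0])); // [-32767, 0, 16383, 32767]
}
```

If the samples are already i16, copying them with `.to_vec()` keeps the waveform intact.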

jason-shen commented 6 months ago

Yeah, always non-speech. The resampler comes back with &[i16], and label expects a Vec; that's what the conversion is for. Maybe I am not doing the conversion correctly.

Do the other parts look good to you though?

jason-shen commented 6 months ago

I removed that conversion and used .to_vec() instead, and added logs like this:


LabeledAudio::NonSpeech(non_speech_data) => {
                println!("Non-speech detected with {} samples", non_speech_data.len());
                // Debugging: Output non-speech data values
                for (i, sample) in non_speech_data.iter().enumerate().take(10) {
                    println!("Non-speech sample {}: {}", i, sample);
                }
            }

Non-speech sample 7: -5
Non-speech sample 8: -13
Non-speech sample 9: -23
Non-speech detected with 512 samples
Non-speech sample 0: -11
Non-speech sample 1: -14
Non-speech sample 2: -17
Non-speech sample 3: -16

Still always non-speech.

nkeenan38 commented 6 months ago

Yeah your usage of this crate looks good to me. You can remove the boxing here, but that won't affect the functionality at all:

        .map(Box::new(|frame: AudioFrame<'_>| {
            process_audio_frame(
                frame,
                &mut resampler,
                target_num_channels,
                target_sample_rate,
            )
        })
            as Box<dyn FnMut(AudioFrame<'_>) -> Vec<i16> + Send>)

can be:

        .map(|frame: AudioFrame<'_>| {
            process_audio_frame(
                frame,
                &mut resampler,
                target_num_channels,
                target_sample_rate,
            )
        })

I think the issue must be in the resampling. I'd try listening yourself to the audio that's being sent to the VAD. If that still sounds right, come back here so we can figure it out.

jason-shen commented 6 months ago

sounds good, cheers for the help

jason-shen commented 6 months ago

Here's one thing though: if the sample rate is 16000, should the .chunk_size(512usize) still be 512?

nkeenan38 commented 6 months ago

Yeah that's fine. This model was trained on chunk sizes of 256, 512, and 768 samples for a sample rate of 8000. It was trained on chunk sizes of 512, 768, and 1024 samples for a sample rate of 16,000. These are recommended, but not required. The only requirement imposed by the model is the sample rate must be no larger than 31.25 times the chunk size.

tl;dr: Yeah, it's alright. It's the smallest recommended chunk size for that sample rate.
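The constraint above (sample rate no larger than 31.25 times the chunk size) can be checked with a few lines of arithmetic. This is a standalone sketch, not a function the crate exposes; the names `satisfies_model_limit` and `chunk_millis` are made up for illustration.

```rust
/// Returns true if the pair satisfies sample_rate <= 31.25 * chunk_size.
fn satisfies_model_limit(sample_rate: u32, chunk_size: usize) -> bool {
    // 31.25 = 125 / 4, so compare in integers to avoid floating point.
    (sample_rate as u64) * 4 <= (chunk_size as u64) * 125
}

/// Duration of audio covered by one chunk, in milliseconds.
fn chunk_millis(sample_rate: u32, chunk_size: usize) -> f64 {
    chunk_size as f64 * 1000.0 / sample_rate as f64
}

fn main() {
    for (rate, chunk) in [(8000u32, 256usize), (16000, 512), (16000, 256)] {
        println!(
            "{rate} Hz, {chunk} samples: ok = {}, {:.1} ms per chunk",
            satisfies_model_limit(rate, chunk),
            chunk_millis(rate, chunk)
        );
    }
}
```

Both 256 samples at 8000 Hz and 512 samples at 16000 Hz sit exactly on the limit and cover 32 ms of audio per prediction, while 256 samples at 16000 Hz falls below it.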

jason-shen commented 6 months ago

OK, I see what my issue is here: every time a chunk is labeled as speech, it gets sent to Whisper on its own. I recorded the audio_data, and the chunks come in small parts. I think I need to do something like this: while is_speaking is true, buffer all the chunks, and when it becomes false, push all the buffered chunks to Whisper. Correct me if I am wrong here, but looking at your examples, I don't think this is needed.
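The buffering idea described here can be sketched roughly as follows. This is one possible shape, not the crate's API: `Labeled` is a stand-in for `LabeledAudio<i16>` so the example compiles on its own, and `UtteranceBuffer` is a hypothetical helper. It does by hand what option C earlier in the thread does with skip_while and take_while: group consecutive speech chunks into one utterance.

```rust
/// Stand-in for the crate's `LabeledAudio<i16>` so this sketch compiles on its own.
enum Labeled {
    Speech(Vec<i16>),
    NonSpeech(Vec<i16>),
}

/// Hypothetical accumulator: collects consecutive speech chunks into one utterance.
#[derive(Default)]
struct UtteranceBuffer {
    samples: Vec<i16>,
}

impl UtteranceBuffer {
    /// Feeds one labeled chunk. Returns a finished utterance when speech has just ended.
    fn push(&mut self, chunk: Labeled) -> Option<Vec<i16>> {
        match chunk {
            Labeled::Speech(data) => {
                self.samples.extend(data);
                None
            }
            // Non-speech closes the current utterance, if there is one.
            Labeled::NonSpeech(_) => self.finish(),
        }
    }

    /// Flushes whatever is buffered; call once more when the stream ends.
    fn finish(&mut self) -> Option<Vec<i16>> {
        if self.samples.is_empty() {
            None
        } else {
            Some(std::mem::take(&mut self.samples))
        }
    }
}

fn main() {
    let chunks = vec![
        Labeled::NonSpeech(vec![0; 4]),
        Labeled::Speech(vec![1; 4]),
        Labeled::Speech(vec![2; 4]),
        Labeled::NonSpeech(vec![0; 4]),
        Labeled::Speech(vec![3; 4]),
    ];

    let mut buffer = UtteranceBuffer::default();
    let mut utterances = Vec::new();
    for chunk in chunks {
        // `push` yields an utterance only at a speech-to-non-speech boundary.
        utterances.extend(buffer.push(chunk));
    }
    // Flush speech that was still in progress when the input ended.
    utterances.extend(buffer.finish());

    for (i, utterance) in utterances.iter().enumerate() {
        println!("utterance {i}: {} samples", utterance.len());
    }
}
```

In the handler from this thread, each returned utterance would be the unit handed to the transcription task, instead of every individual 512-sample chunk.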