nkeenan38 / voice_activity_detector

A Voice Activity Detector rust library using the Silero VAD model.
https://crates.io/crates/voice_activity_detector
MIT License

Segmentation fault #37

Open adri1wald opened 3 months ago

adri1wald commented 3 months ago

I've written a test suite and it seg faults every couple of runs with outputs like:

Caused by:
  process didn't exit successfully: `[REDACTED]/target/debug/deps/meeting_bot-d385ad40d7630a91` (signal: 11, SIGSEGV: invalid memory reference)
meeting_bot-d385ad40d7630a91(92412,0x16c927000) malloc: Heap corruption detected, free list is damaged at 0x6000010611c0
*** Incorrect guard value: 188586888663487
meeting_bot-d385ad40d7630a91(92412,0x16c927000) malloc: *** set a breakpoint in malloc_error_break to debug
error: test failed, to rerun pass `--lib`

Caused by:
  process didn't exit successfully: `[REDACTED]/target/debug/deps/meeting_bot-d385ad40d7630a91` (signal: 6, SIGABRT: process abort signal)

I suspect the source of the issue is in ort since there is no unsafe in this crate, but posting here first.

System information

Software:

    System Software Overview:

      System Version: macOS 13.5 (22G74)
      Kernel Version: Darwin 22.6.0
      Boot Volume: Macintosh HD
      Boot Mode: Normal
      Computer Name: REDACTED
      User Name: REDACTED
      Secure Virtual Memory: Enabled
      System Integrity Protection: Enabled
      Time since boot: 2 days, 7 hours, 43 minutes

Hardware:

    Hardware Overview:

      Model Name: MacBook Pro
      Model Identifier: Mac14,6
      Model Number: MNXA3LL/A
      Chip: Apple M2 Max
      Total Number of Cores: 12 (8 performance and 4 efficiency)
      Memory: 64 GB
      System Firmware Version: 8422.141.2
      OS Loader Version: 8422.141.2
      Serial Number (system): REDACTED
      Hardware UUID: REDACTED
      Provisioning UDID: REDACTED
      Activation Lock Status: Enabled

Reproduction below. There's some semantically nonsensical stuff in there if you look closely, but that's because I simplified it while keeping it structurally equivalent to our actual tests.

use anyhow::{Context, Result};
use std::time::Duration;
use voice_activity_detector::{IteratorExt, LabeledAudio, VoiceActivityDetector};

#[derive(Debug, Clone)]
pub(crate) struct SpeechSegment {
    start: Duration,
    end: Duration,
}

impl SpeechSegment {
    pub(crate) fn new(start: Duration, end: Duration) -> Self {
        Self { start, end }
    }

    pub(crate) fn start(&self) -> Duration {
        self.start
    }

    pub(crate) fn end(&self) -> Duration {
        self.end
    }
}

#[derive(Debug, Clone)]
pub(crate) struct DetectionParameters {
    /// The resolution of the analysis
    pub(crate) resolution: Duration,
    /// The threshold for voice detection (0.0 to 1.0)
    pub(crate) threshold: f32,
    /// The number of chunks to pad the detected segments
    pub(crate) padding_chunks: usize,
    /// The minimum gap between segments.
    ///
    /// If the gap between two segments is closer than this, they will be merged.
    ///
    /// If the gap between the start of the audio file and the first segment is closer than this,
    /// the start of the segment will be moved to the start of the audio file.
    ///
    /// If the gap between the end of the audio file and the last segment is closer than this,
    /// the end of the segment will be moved to the end of the audio file.
    pub(crate) min_gap: Duration,
}

impl Default for DetectionParameters {
    fn default() -> Self {
        Self {
            resolution: Duration::from_millis(64),
            threshold: 0.5,
            padding_chunks: 3,
            min_gap: Duration::ZERO,
        }
    }
}

/// Detect speech segments in the given audio data
///
/// Should be wrapped in `tokio::task::spawn_blocking` if called from a Tokio runtime.
pub(crate) fn detect_speech<R: std::io::Read>(
    reader: hound::WavReader<R>,
    parameters: DetectionParameters,
) -> Result<Vec<SpeechSegment>> {
    let DetectionParameters {
        resolution,
        threshold,
        padding_chunks,
        min_gap,
    } = parameters;

    let sample_rate = reader.spec().sample_rate;
    let chunk_size = (resolution.as_secs_f64() * sample_rate as f64) as usize;

    // The only requirement imposed by the underlying model is that the sample
    // rate must be no larger than 31.25 times the chunk size.
    if sample_rate as f64 > 31.25 * chunk_size as f64 {
        return Err(anyhow::anyhow!(
            "Sample rate is too high for the given resolution"
        ));
    }

    let vad = VoiceActivityDetector::builder()
        .sample_rate(sample_rate)
        .chunk_size(chunk_size)
        .build()
        .context("Failed to create VAD")?;

    let samples = reader
        .into_samples::<i16>()
        .collect::<Result<Vec<_>, _>>()
        .context("Failed to read samples")?;
    let duration = Duration::from_secs_f64(samples.len() as f64 / sample_rate as f64);

    let labels = samples.into_iter().label(vad, threshold, padding_chunks);

    let mut segments: Vec<SpeechSegment> = Vec::new();
    let mut sample_count = 0usize;

    for label in labels {
        let start = Duration::from_secs_f64(sample_count as f64 / sample_rate as f64);
        let (is_speech, size) = match label {
            LabeledAudio::Speech(chunk) => (true, chunk.len()),
            LabeledAudio::NonSpeech(chunk) => (false, chunk.len()),
        };
        sample_count += size;
        let end = Duration::from_secs_f64(sample_count as f64 / sample_rate as f64);
        if is_speech {
            segments.push(SpeechSegment::new(start, end));
        }
    }

    if segments.is_empty() {
        return Ok(segments);
    }

    let mut merged_segments = Vec::with_capacity(segments.len());
    merged_segments.push(segments[0].clone());

    for current in segments.iter().skip(1) {
        let previous = merged_segments
            .last_mut()
            .expect("we pushed a segment before the loop");
        if current.start() - previous.end() < min_gap {
            previous.end = current.end();
        } else {
            merged_segments.push(current.clone());
        }
    }

    if let Some(first) = merged_segments.first_mut() {
        if first.start() < min_gap {
            first.start = Duration::ZERO;
        }
    }

    if let Some(last) = merged_segments.last_mut() {
        if duration - last.end() < min_gap {
            last.end = duration;
        }
    }

    Ok(merged_segments)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::{io::Cursor, time::Duration};

    fn generate_silence_audio(duration: Duration, sample_rate: u32) -> Vec<i16> {
        let num_samples = (duration.as_secs_f64() * sample_rate as f64) as usize;
        vec![0i16; num_samples]
    }

    fn generate_test_audio(segments: &[(bool, Duration)]) -> hound::WavReader<Cursor<Vec<u8>>> {
        let mut buffer = std::io::Cursor::new(Vec::new());

        let mut writer = hound::WavWriter::new(
            &mut buffer,
            hound::WavSpec {
                channels: 1,
                sample_rate: 16000,
                bits_per_sample: 16,
                sample_format: hound::SampleFormat::Int,
            },
        )
        .unwrap();

        for (is_speech, duration) in segments {
            if *is_speech {
                for sample in generate_silence_audio(*duration, writer.spec().sample_rate) {
                    writer.write_sample(sample).unwrap()
                }
            } else {
                for sample in generate_silence_audio(*duration, writer.spec().sample_rate) {
                    writer.write_sample(sample).unwrap()
                }
            };
        }

        drop(writer);

        hound::WavReader::new(Cursor::new(buffer.into_inner())).unwrap()
    }

    #[test]
    fn test_no_speech() {
        let audio = generate_test_audio(&[(false, Duration::from_secs(5))]);
        let params = DetectionParameters::default();
        let segments = detect_speech(audio, params).unwrap();
        assert!(segments.is_empty());
    }

    #[test]
    fn test_continuous_speech() {
        let audio = generate_test_audio(&[(true, Duration::from_secs(5))]);
        let params = DetectionParameters::default();
        let segments = detect_speech(audio, params).unwrap();
        assert_eq!(segments.len(), 1);
        assert_eq!(segments[0].start(), Duration::ZERO);
        assert_eq!(segments[0].end(), Duration::from_secs(5));
    }

    #[test]
    fn test_multiple_segments() {
        let audio = generate_test_audio(&[
            (true, Duration::from_secs(1)),
            (false, Duration::from_secs(1)),
            (true, Duration::from_secs(1)),
            (false, Duration::from_secs(1)),
            (true, Duration::from_secs(1)),
        ]);
        let params = DetectionParameters::default();
        let segments = detect_speech(audio, params).unwrap();
        assert_eq!(segments.len(), 3);
    }

    #[test]
    fn test_merge_close_segments() {
        let audio = generate_test_audio(&[
            (true, Duration::from_secs(1)),
            (false, Duration::from_millis(100)),
            (true, Duration::from_secs(1)),
        ]);
        let mut params = DetectionParameters::default();
        params.min_gap = Duration::from_millis(200);
        let segments = detect_speech(audio, params).unwrap();
        assert_eq!(segments.len(), 1);
        assert_eq!(
            segments[0].end() - segments[0].start(),
            Duration::from_millis(2100)
        );
    }

    #[test]
    fn test_boundary_conditions() {
        let audio = generate_test_audio(&[
            (true, Duration::from_millis(500)),
            (false, Duration::from_secs(4)),
            (true, Duration::from_millis(500)),
        ]);
        let params = DetectionParameters::default();
        let segments = detect_speech(audio, params).unwrap();
        assert_eq!(segments.len(), 2);
        assert_eq!(segments[0].start(), Duration::ZERO);
        assert_eq!(segments[1].end(), Duration::from_secs(5));
    }

    #[test]
    fn test_short_audio() {
        let audio = generate_test_audio(&[(true, Duration::from_millis(100))]);
        let params = DetectionParameters::default();
        let segments = detect_speech(audio, params).unwrap();
        dbg!(&segments);
        assert_eq!(segments.len(), 1);
        assert!(false);
    }

    #[test]
    fn test_different_thresholds() {
        let audio = || {
            generate_test_audio(&[
                (true, Duration::from_secs(1)),
                (false, Duration::from_secs(1)),
                (true, Duration::from_secs(1)),
            ])
        };

        let mut params = DetectionParameters::default();
        params.threshold = 0.3;
        let segments_low_threshold = detect_speech(audio(), params.clone()).unwrap();

        params.threshold = 0.7;
        let segments_high_threshold = detect_speech(audio(), params).unwrap();

        assert!(segments_low_threshold.len() >= segments_high_threshold.len());
    }

    #[test]
    fn test_different_resolutions() {
        let audio = || {
            generate_test_audio(&[
                (true, Duration::from_secs(1)),
                (false, Duration::from_secs(1)),
                (true, Duration::from_secs(1)),
            ])
        };

        let mut params = DetectionParameters::default();
        params.resolution = Duration::from_millis(32);
        let segments_high_res = detect_speech(audio(), params.clone()).unwrap();

        params.resolution = Duration::from_millis(128);
        let segments_low_res = detect_speech(audio(), params).unwrap();

        assert!(segments_high_res.len() >= segments_low_res.len());
    }

    #[test]
    fn test_padding_chunks() {
        let audio = || {
            generate_test_audio(&[
                (true, Duration::from_secs(1)),
                (false, Duration::from_secs(1)),
                (true, Duration::from_secs(1)),
            ])
        };

        let mut params = DetectionParameters::default();
        params.padding_chunks = 0;
        let segments_no_padding = detect_speech(audio(), params.clone()).unwrap();

        params.padding_chunks = 5;
        let segments_with_padding = detect_speech(audio(), params).unwrap();

        assert!(
            segments_with_padding[0].end() - segments_with_padding[0].start()
                > segments_no_padding[0].end() - segments_no_padding[0].start()
        );
    }
}
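For reference, the chunk-size arithmetic in detect_speech can be checked in isolation. The sketch below is a minimal, std-only illustration that uses integer arithmetic instead of the f64 math in the listing. The helper names chunk_size_for and within_runtime_limit are hypothetical, and the 31.25 ratio is taken from the comment in the listing rather than verified against the runtime.

```rust
use std::time::Duration;

/// Hypothetical helper: samples per chunk for a given resolution, rounded down.
/// Works in nanoseconds so no floating-point rounding is involved.
pub fn chunk_size_for(resolution: Duration, sample_rate: u32) -> usize {
    let samples = resolution.as_nanos() * u128::from(sample_rate) / 1_000_000_000;
    samples as usize
}

/// Hypothetical helper: true when `sample_rate <= 31.25 * chunk_size`.
/// 31.25 is written as 3125 / 100 so the comparison stays in integers.
pub fn within_runtime_limit(sample_rate: u32, chunk_size: usize) -> bool {
    u128::from(sample_rate) * 100 <= 3125 * chunk_size as u128
}
```

With the defaults used in the tests (64 ms at 16000 Hz) this gives 1024 samples per chunk, and 16000 / 1024 = 15.625, which is inside the stated bound. A 32 ms resolution gives 512 samples, which sits exactly on the bound (16000 / 512 = 31.25).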

nkeenan38 commented 3 months ago

Can't say I have much here. Definitely no unsafe usage in this crate. Could it be an issue of running out of memory? Have you tried running these tests without parallelization (cargo test -- --test-threads 1)?
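A narrower variant of the single-threaded run is to serialize only the calls into the detector and leave the rest of the test suite parallel. The following is a minimal std-only sketch of that idea. The NATIVE_LOCK static and the serialized helper are hypothetical names, not part of voice_activity_detector or ort, and this is meant as a diagnostic aid (does the crash disappear when the calls no longer overlap?) rather than a fix.

```rust
use std::sync::{Mutex, PoisonError};

/// Hypothetical process-wide lock; not provided by voice_activity_detector or ort.
static NATIVE_LOCK: Mutex<()> = Mutex::new(());

/// Runs `f` while holding the global lock, so at most one caller at a time
/// is inside the guarded section.
pub fn serialized<T>(f: impl FnOnce() -> T) -> T {
    // Recover from poisoning so that one panicking test does not make
    // every later test fail on the lock itself.
    let _guard = NATIVE_LOCK
        .lock()
        .unwrap_or_else(PoisonError::into_inner);
    f()
}
```

In the reproduction, each test could then call something like serialized(|| detect_speech(audio, params)). If the crashes stop, that points at concurrent use of the native runtime; if they continue, the cause is likely elsewhere.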

nkeenan38 commented 3 months ago

I'm also not too familiar with the internals of ort, but I wonder if using the load-dynamic feature would give more consistent results instead of the default download behavior.

adri1wald commented 3 months ago

Hey @nkeenan38 thanks for responding!

Definitely not running out of memory.

I imagine running without parallelisation would decrease the likelihood of seg faults, but in prod this will be part of a web server, so parallelised tests are actually closer to reality.
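To make the parallel case easier to reproduce than "every couple of runs", one option is a small stress harness that releases several threads at the same moment and has each repeat the call many times. This is a minimal std-only sketch. The hammer function is a hypothetical helper, and the closure passed to it stands in for the real work (building the WAV data and calling detect_speech).

```rust
use std::sync::{Arc, Barrier};
use std::thread;

/// Hypothetical helper: runs `work` on `threads` OS threads, `iterations`
/// times each. A barrier releases all threads together so the calls overlap
/// as much as possible. Returns the total number of completed calls.
pub fn hammer<F>(threads: usize, iterations: usize, work: F) -> usize
where
    F: Fn(usize, usize) + Send + Sync + 'static,
{
    let work = Arc::new(work);
    let barrier = Arc::new(Barrier::new(threads));
    let handles: Vec<_> = (0..threads)
        .map(|thread_index| {
            let work = Arc::clone(&work);
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                // Wait until every worker has been spawned.
                barrier.wait();
                for iteration in 0..iterations {
                    work(thread_index, iteration);
                }
                iterations
            })
        })
        .collect();
    handles
        .into_iter()
        .map(|handle| handle.join().expect("worker thread panicked"))
        .sum()
}
```

A single test that calls hammer with, say, 8 threads and a few hundred iterations would exercise the same overlap a web server produces, without depending on how the test runner happens to schedule separate tests.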

I can try load-dynamic, but honestly I'm a bit hesitant to use this crate if it's backed by ort. I tried forking this repo and replacing ort with alternatives, but so far no success.

Do you think it would make sense to submit this as an issue to the ort repo?