juanmc2005 / diart

A python package to build AI-powered real-time audio applications
https://diart.readthedocs.io
MIT License

Minimizing missed detection #224

Closed: arielrado closed this issue 5 months ago

arielrado commented 11 months ago

I've been working on tuning the pipeline for my application, which is a real-time conversational system. The best results so far are:

36.02% DER, 2.41% false alarm, 20.04% missed detection, and 13.56% confusion

I used my own data, which was recorded on the target platform's microphone array and annotated with webrtc vad. Our system is reasonably tolerant of false alarms, since we require ASR to label a speaker, but not of missed detections. Is there a way to make the pipeline less conservative?

I am using segmentation-3.0, the default embedding, and hyperparameters achieved through tuning: tau=0.493, rho=0.055, delta=0.633. I tried lowering tau and increasing rho, but that just increased the confusion rate. I also tried changing gamma and beta, as I understood that lowering them can yield the desired results, but they seem to have a negligible effect.

juanmc2005 commented 11 months ago

Hi @arielrado,

I used my own data which was recorded on the target platform's microphone array and annotated with webrtc vad

What do you mean "annotated with webrtc vad"? How does that annotation process work? Have you checked that your annotations are good?

arielrado commented 11 months ago

Thanks for responding!

I used the code below. Essentially, I used webrtc to determine where in a single-speaker recording there is speech (in 30 ms chunks), then created a wav file containing a few speakers and an RTTM file to go with it. I hand-checked some of the files to make sure they are correct; there is some error inherent in using a VAD, but I haven't seen so much of it that it would affect the score.

import os
import random
import numpy as np
from pydub import AudioSegment
import webrtcvad

def detect_speech_segments(audio, vad:webrtcvad.Vad):
    samples = np.array(audio.get_array_of_samples())
    frame_rate = audio.frame_rate
    frame_duration = 30  # in milliseconds
    samples_per_frame = int(audio.frame_rate * frame_duration / 1000.0)

    # Convert to 16-bit PCM format
    samples = samples.astype(np.int16)

    # Perform VAD
    is_speech = []
    for start in range(0, len(samples)-len(samples)%samples_per_frame, samples_per_frame):
        end = min(start + samples_per_frame, len(samples))
        is_speech.append(vad.is_speech(samples[start:end].tobytes(), audio.frame_rate))

    # Find speech segments
    speech_segments = []
    start_time = 0
    for i, speech_flag in enumerate(is_speech):
        if i == 0 or speech_flag != is_speech[i - 1]:
            if speech_flag:
                start_time = i * frame_duration / 1000
            else:
                end_time = i * frame_duration / 1000
                speech_segments.append((start_time, end_time))

    # Ensure the last segment is complete
    if is_speech[-1]:
        speech_segments.append((start_time, len(samples) / frame_rate))

    return speech_segments

def generate_diarization_data(input_directory, output_directory, num_speakers, output_filename):
    # Get a list of speaker directories
    speaker_directories = [d for d in os.listdir(input_directory) if os.path.isdir(os.path.join(input_directory, d)) and d!='1001b']

    # Select random speakers
    selected_speakers = random.sample(speaker_directories, num_speakers)

    # Initialize variables for RTTM file
    rttm_data = []

    # Create a VAD object
    vad = webrtcvad.Vad()

    audio = AudioSegment.empty()

    # Iterate through selected speakers
    for speaker_id in selected_speakers:
        speaker_path = os.path.join(input_directory, speaker_id)

        # Get a list of wav files for the selected speaker
        wav_files = [f for f in os.listdir(speaker_path) if f.endswith('.wav')]

        # Select random audio segments for the speaker
        selected_segments = random.sample(wav_files, min(2, len(wav_files)))  # You can adjust the number of segments

        # Concatenate selected segments into a single audio file
        combined_audio = AudioSegment.from_wav(os.path.join(speaker_path, selected_segments[0]))

        for segment in selected_segments[1:]:
            segment_audio = AudioSegment.from_wav(os.path.join(speaker_path, segment))
            combined_audio += segment_audio

        # Detect speech segments using VAD
        speech_segments = detect_speech_segments(combined_audio, vad)

        # Generate RTTM entries for the speech segments
        for segment_start, segment_end in speech_segments:
            if segment_start < segment_end:
                rttm_data.append(f"SPEAKER {output_filename} 1 {audio.duration_seconds+segment_start:.3f} {segment_end-segment_start:.3f} <NA> <NA> {speaker_id} <NA>")

        audio+=combined_audio

    # Save the combined audio
    output_path = os.path.join(output_directory, 'wav_data',f"{output_filename}.wav")
    audio.export(output_path, format="wav")

    # Save RTTM file
    rttm_path = os.path.join(output_directory, 'rttm_data',f"{output_filename}.rttm")
    with open(rttm_path, 'w') as rttm_file:
        rttm_file.write('\n'.join(rttm_data))

if __name__ == "__main__":
    input_directory = "/home/non-enc-volume/datasets/noramlized_recordings_output_25-8-2023_0-52_full_db_enh_run_2"
    output_directory = "/home/storage01/diart/rttm_ariel"
    max_speakers = 5 
    samples_per_num_speakers = 100
    output_filename = "generated_diarization_data"
    for num_speakers in range(2,max_speakers+1):
        for i in range(samples_per_num_speakers):
            output_filename = f'{num_speakers}_{i}'
            generate_diarization_data(input_directory, output_directory, num_speakers, output_filename)
juanmc2005 commented 11 months ago

@arielrado do you know what's the detection error rate of your VAD on the same data? i.e. false alarm and missed detection.

I think this may also be related to a deeper problem in diart that I've been meaning to address for some time, which is that tau_active is used to both detect speech (VAD) and speakers that should be identified during clustering.

Consider this: if you increase tau_active, you naturally increase missed detection and reduce false alarm, but you're also more conservative in deciding which speakers get identified (and thus used to improve internal speaker representations). Since fewer speakers satisfy this new tau_active, you have a higher chance of ending up with bad or incomplete representations, so your confusion increases.

Conversely, if you decrease tau_active you reduce missed detection and increase false alarm, but now you're too permissive in allowing speaker representations to be changed (this also depends on rho_update to be fair), so your confusion may also go up.

If my intuition is correct, the tuning should find a point in hyper-parameter space where these two scenarios either balance out or one wins over the other, but in any case it will be the point with the best possible DER. Your tuning doesn't really care if your miss is too high; it just cares about the sum of the three components (you could also try changing the tuning metric to weigh missed detection differently, as sketched below).
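
As a rough sketch of that last idea, assuming you evaluate with pyannote.metrics on reference/hypothesis pairs of pyannote Annotations: the helper name and the weight of 2.0 are placeholders, and how you plug this into the tuning loop depends on your diart version, but the weighting itself could look something like this.

from pyannote.metrics.diarization import DiarizationErrorRate

def weighted_der(reference, hypothesis, miss_weight: float = 2.0) -> float:
    """DER variant that penalizes missed detection more heavily than false alarm."""
    # detailed=True returns the raw components (in seconds) alongside the rate
    components = DiarizationErrorRate()(reference, hypothesis, detailed=True)
    total = components["total"]  # total reference speech duration
    return (
        components["false alarm"]
        + miss_weight * components["missed detection"]
        + components["confusion"]
    ) / total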

This is also a very interesting use case for me to better understand this problem and find a nice solution. I would suggest you first take a look at the miss and false alarm of your VAD. If its performance is better than diart's, I'd suggest you use tau_active only as a clustering hyper-parameter and rely on the VAD for speech detection. Then, in real time, you should be able to resolve discrepancies between the two parallel systems. For example, if the VAD detects speech but diart misses it, what do you do?
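
To make that discrepancy question concrete, here is a small sketch with pyannote.core (diart's per-chunk prediction is a pyannote Annotation, so these operations apply directly; vad_regions and chunk are placeholders for your external VAD output and the chunk extent):

from pyannote.core import Annotation, Segment, Timeline

def agreed_speech(prediction: Annotation, vad_regions: Timeline) -> Annotation:
    """Keep only the parts of diart's prediction that the external VAD also marks as speech."""
    return prediction.crop(vad_regions, mode="intersection")

def vad_only_speech(prediction: Annotation, vad_regions: Timeline, chunk: Segment) -> Timeline:
    """Regions where the VAD hears speech but diart predicts none.
    These are the discrepancies you need a policy for (e.g. relabel them
    with the embedding model, or emit a generic 'speech' label)."""
    no_diart_speech = prediction.get_timeline().support().gaps(support=chunk)
    return vad_regions.crop(no_diart_speech, mode="intersection")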

Another option would be to implement a different clustering strategy. For example, you could use the overlapping part of two consecutive chunk outputs to decide the speaker permutation of the new prediction and then attempt to detect new speakers with speaker embeddings (big challenge here). You may want to take a look at PR #201 that does something like this without the new speaker detection thing (notice that all hyper-parameters are gone here).
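
For reference, the permutation part of that strategy can be sketched generically with numpy/scipy (this is not diart's actual implementation; prev_overlap and new_overlap are assumed to be frame-by-speaker score matrices restricted to the frames shared by two consecutive chunks):

import numpy as np
from scipy.optimize import linear_sum_assignment

def align_speakers(prev_overlap: np.ndarray, new_overlap: np.ndarray) -> np.ndarray:
    """Find the column permutation of the new chunk that best matches the
    previous chunk on their overlapping frames."""
    # pairwise similarity between every previous and new speaker activity curve
    similarity = np.einsum("tf,tg->fg", prev_overlap, new_overlap)
    # maximizing total similarity == minimizing its negation
    _, permutation = linear_sum_assignment(-similarity)
    return permutation  # new_overlap[:, permutation] lines up with prev_overlap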

arielrado commented 11 months ago

Hi juanmc2005,

I would suggest you first take a look at the miss and false alarm of your VAD.

I don't know what the precise error rate is on my data since I don't have hand annotations; according to this, webrtc has an F1 score of 0.819 on raw audio, which isn't very good. My data was recorded in a quiet environment, so I will use a simple energy-level VAD instead.

I'd suggest you use tau_active only as a clustering hyper-parameter and rely on the VAD for speech detection.

I don't have a lot of experience working with rxpy, so the codebase is pretty confusing to me. How would you suggest I go about implementing this? What sections should I be looking at? Should I implement the VAD as an extra block in the pipeline? I will use a more powerful VAD like MarbleNet.

I had an idea to use a full offline diarization pipeline instead of the segmentation block. I am using a pretty powerful GPU, and since I understand that making the pipeline lightweight was a priority, I might be able to make use of the hardware overhead.

Thanks for the help so far, I'll keep you posted!

juanmc2005 commented 10 months ago

I don't have a lot of experience working with rxpy, so the codebase is pretty confusing to me. How would you suggest I go about implementing this? What sections should I be looking at? Should I implement the VAD as an extra block in the pipeline? I will use a more powerful VAD like MarbleNet.

You don't need to know rxpy to modify a pipeline. If you look at the implementation of SpeakerDiarization, you'll see that you can just change what happens each time you receive a new audio chunk. The streaming part with rxpy is handled by StreamingInference, and you can use it as is.
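
For example, a rough skeleton of that approach could look like the following. Treat it as the shape of the idea rather than working code: the exact __call__ signature may differ between diart versions, and external_vad stands in for whatever VAD you end up using (it is assumed to return a pyannote Timeline of speech regions for the chunk).

from typing import Sequence, Tuple

from pyannote.core import Annotation, SlidingWindowFeature

from diart import SpeakerDiarization

class VADFilteredDiarization(SpeakerDiarization):
    def __init__(self, config=None, external_vad=None):
        super().__init__(config)
        self.external_vad = external_vad  # placeholder for e.g. a MarbleNet wrapper

    def __call__(
        self, waveforms: Sequence[SlidingWindowFeature]
    ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
        outputs = super().__call__(waveforms)
        filtered = []
        for prediction, waveform in outputs:
            if self.external_vad is not None:
                # combine diart's prediction with the external VAD here,
                # e.g. the intersection/discrepancy logic sketched earlier
                speech_regions = self.external_vad(waveform)
                prediction = prediction.crop(speech_regions, mode="intersection")
            filtered.append((prediction, waveform))
        return filtered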

I had an idea to use a full offline diarization pipeline instead of the segmentation block. I am using a pretty powerful GPU, and since I understand that making the pipeline lightweight was a priority, I might be able to make use of the hardware overhead.

I don't think using a full offline diarization pipeline for each chunk is a good idea, as you'll have the same problems that you have with a segmentation block, for example deciding which speakers of the new chunk correspond to the speakers in previous ones. That part is handled by OnlineSpeakerClustering, and you won't be able to get rid of it with an offline model. At best you'll squeeze out a bit more performance, but I don't think it will fundamentally change your problem (although I may be wrong).

To verify whether the issue is coming from segmentation or from clustering, you could use the segmentation block as a VAD (you can simply take the max score across all speakers) and measure the false alarm and miss on your data. If the miss is bad, you have a segmentation problem; if it's better, then the problem comes from clustering, as I hypothesized in my previous message.
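
A possible way to run that check, assuming you can get the frame-level segmentation scores for a chunk as a (num_frames, num_speakers) array with a known frame step in seconds (the helper name, frame_step and tau are placeholders):

import numpy as np
from pyannote.core import Annotation, Segment
from pyannote.metrics.detection import DetectionErrorRate

def segmentation_as_vad(scores: np.ndarray, frame_step: float, tau: float,
                        chunk_start: float = 0.0) -> Annotation:
    """Collapse per-speaker segmentation scores into a single 'speech' annotation
    by taking the max over speakers and thresholding at tau."""
    active = scores.max(axis=1) >= tau
    vad = Annotation()
    start = None
    for i, is_active in enumerate(active):
        t = chunk_start + i * frame_step
        if is_active and start is None:
            start = t
        elif not is_active and start is not None:
            vad[Segment(start, t)] = "speech"
            start = None
    if start is not None:
        vad[Segment(start, chunk_start + len(active) * frame_step)] = "speech"
    return vad

# Then compare against the reference collapsed to a single 'speech' label:
# DetectionErrorRate()(reference_speech, hypothesis_speech, detailed=True)
# reports the false alarm and missed detection components separately.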

You could also consider using pyannote/segmentation-3.0, which is more robust and has better performance.

arielrado commented 10 months ago

Hi, I regenerated my test data with pydub.detect_nonsilent instead of webrtc, and the missed detection has transferred into false alarm. It seems that the high FP rate of the VAD caused an artificial increase in FN.

DER     FA      MD     CF
32.52   15.32   2.63   14.57

I can reduce that by adding VAD filtering at the end of the pipeline (using a more reliable VAD): my plan is to only identify a speaker if the external VAD is positive and segmentation has tagged them as active.

I'd like to reduce the confusion. I tried reducing the max_speakers parameter, but that caused confusion to increase, contrary to my intuition. Suppose I know that there are no more than 3 speakers in a given interaction; is there a better clustering method I can use? I have added a timeout feature in my fork for our use case, so the pipeline does not need to keep track of a large number of speakers.

You may want to take a look at PR https://github.com/juanmc2005/diart/pull/201 that does something like this without the new speaker detection thing (notice that all hyper-parameters are gone here).

I tried out this method but it yields worse results overall.

You could also consider using pyannote/segmentation-3.0, which is more robust and has better performance.

I have been using this method since I found it to perform better when I first started testing diart.

Thanks for all the help so far!

juanmc2005 commented 9 months ago

@arielrado this will be more of a trial and error effort to bring the confusion down. Reducing speaker confusion in real-time applications is still an open problem in the speaker diarization community.

I can reduce that by adding VAD filtering at the end of the pipeline (using a more reliable VAD): my plan is to only identify a speaker if the external VAD is positive and segmentation has tagged them as active.

Well, at some point you'll have to decide what you do when the VAD and diarization disagree; for example, you can attempt a relabeling using the embedding model if the VAD detects speech that diart missed. It's rather difficult to say what can be done in your setting without in-depth knowledge of your data and concrete application. Optimizing the performance of the library for your particular application is not trivial.
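
To illustrate the relabeling idea: once you have an embedding for a VAD-only region (the embedding extraction itself is left out here), you can compare it against the running speaker centroids, something like this (relabel_missed_region, centroids and delta are placeholders, not diart internals):

import numpy as np

def relabel_missed_region(embedding: np.ndarray, centroids: dict, delta: float):
    """Assign a region that diart missed to the closest known speaker centroid,
    or to a generic 'speech' label if no centroid is close enough."""
    best_speaker, best_dist = None, np.inf
    for speaker, centroid in centroids.items():
        # cosine distance between the region embedding and the speaker centroid
        dist = 1.0 - np.dot(embedding, centroid) / (
            np.linalg.norm(embedding) * np.linalg.norm(centroid)
        )
        if dist < best_dist:
            best_speaker, best_dist = speaker, dist
    return best_speaker if best_dist <= delta else "speech"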

Suppose I know that there are no more than 3 speakers in a given interaction; is there a better clustering method I can use?

You can probably get better performance by constraining the system to a maximum of 3 speakers. As for the clustering method, you could try writing your own or adapting an existing offline algorithm (you could check out pyannote.audio for this). Just be careful not to increase the latency too much.
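
For the speaker cap, that should just be a config change, along these lines (parameter names as in recent diart versions, with the tuned values from earlier in the thread; double-check them against the version you're running):

from diart import SpeakerDiarization, SpeakerDiarizationConfig

config = SpeakerDiarizationConfig(
    tau_active=0.493,   # values from the tuning run above
    rho_update=0.055,
    delta_new=0.633,
    max_speakers=3,     # hard cap on how many speakers the clustering can track
)
pipeline = SpeakerDiarization(config)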