fgnt / pb_bss

Collection of EM algorithms for blind source separation of audio signals
MIT License
271 stars 60 forks

How to run the example please? #25

Open zcy618 opened 4 years ago

zcy618 commented 4 years ago

Dear friend, could you give some instructions showing how to run the example, please? Thanks.

boeddeker commented 4 years ago

Dear zcy618,

Do you mean how to run the file https://github.com/fgnt/pb_bss/blob/master/examples/mixture_model_example.ipynb ? That is a Jupyter notebook. The Jupyter documentation explains how to launch a notebook server: https://jupyterlab.readthedocs.io/en/stable/

pfeatherstone commented 4 years ago

Is there a simple example for inference?

boeddeker commented 4 years ago

I am not sure what you mean by inference. The mixture models are categorized as a blind source separation technique.

They have no training phase like neural networks. They learn their parameters on a single example and are usually applied only to that example. Applying them to an independent example will not work, because the spatial properties are different.

pfeatherstone commented 4 years ago

ah ok got you

pfeatherstone commented 4 years ago

Can this be used with a microphone array for example to do source separation?

boeddeker commented 4 years ago

Can this be used with a microphone array for example to do source separation?

Yes, the idea of the mixture models (cACG and complex Watson) is to model spatial differences between the different sources. In https://github.com/fgnt/pb_bss/blob/master/examples/mixture_model_example.ipynb the observation is a multichannel signal with multiple sources at different positions.

pfeatherstone commented 4 years ago

So if all I have is the observation, i.e. the multichannel signal, what are speech_image_0.wav, speech_image_1.wav, ..., noise_image.wav, speech_source_0.wav, speech_source_1.wav, ... ?

pfeatherstone commented 4 years ago

Are the others only used for metrics?

boeddeker commented 4 years ago

So if all I have is the observation, i.e. the multichannel signal, what are speech_image_0.wav, speech_image_1.wav, ..., noise_image.wav, speech_source_0.wav, speech_source_1.wav, ... ?

These files were used to generate the observation.

Are the others only used for metrics?

Yes, inside the notebook we use these signals to get an idea of the performance. So they are only used for metrics and visualization.
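
For illustration only, such a metric computation could look like the following sketch with mir_eval (mir_eval and the placeholder signals are assumptions here, not necessarily what the notebook uses):

import numpy as np
import mir_eval

# Placeholders standing in for the reference sources and the separated estimates;
# in the notebook these would be the loaded reference wavs and the mixture model output.
speech_source = np.random.randn(2, 16000)
speech_estimate = np.random.randn(2, 16000)

# bss_eval resolves the permutation between estimates and references
# and returns SDR / SIR / SAR per source.
sdr, sir, sar, perm = mir_eval.separation.bss_eval_sources(speech_source, speech_estimate)
print(sdr, perm)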

pfeatherstone commented 4 years ago

So if I have a multichannel audio file and I want to extract 3 candidate sources, how can I do that? I have this so far:

import  soundfile
import  numpy as np
from    nara_wpe.utils import stft, istft
from    pb_bss.distribution import CACGMMTrainer
from    pb_bss.permutation_alignment import DHTVPermutationAlignment, OraclePermutationAlignment
from    einops import rearrange

def soundfile_read(file):
    data, data_sample_rate = soundfile.read(file)
    return data_sample_rate, np.ascontiguousarray(data.T)

def bss_pb_bss(multichannel_wav):
    samplerate, data = soundfile_read(multichannel_wav)
    Observation = stft(data, 512, 128)
    trainer = CACGMMTrainer()
    Observation_mm = rearrange(Observation, 'd t f -> f t d')
    model = trainer.fit(
        Observation_mm,
        num_classes=3,
        iterations=40,
        inline_permutation_aligner=None
    )
    affiliation = model.predict(Observation_mm)
    pa = DHTVPermutationAlignment.from_stft_size(512)
    affiliation_pa = pa(rearrange(affiliation, 'f k t -> k f t'))
    Speech_image_0_est  = Observation[0, :, :].T * affiliation_pa[0, :, :]
    Speech_image_1_est  = Observation[0, :, :].T * affiliation_pa[1, :, :]
    Noise_image_est     = Observation[0, :, :].T * affiliation_pa[2, :, :]
    speech_image_0_est  = istft(Speech_image_0_est.T, 512, 128)[..., :Observation.shape[-1]]
    speech_image_1_est  = istft(Speech_image_1_est.T, 512, 128)[..., :Observation.shape[-1]]
    noise_image_est     = istft(Noise_image_est.T, 512, 128)[..., :Observation.shape[-1]]

pfeatherstone commented 4 years ago

It looks like OraclePermutationAlignment requires global_pa_reference, which requires Noise_image, which I don't have.

pfeatherstone commented 4 years ago

Do I need to resample speech_image_0_est, speech_image_1_est and noise_image_est? I think I'm missing something.

boeddeker commented 4 years ago

The OraclePermutationAlignment is used to identify the first speaker, the second speaker and the noise. In a real application you cannot do this; instead, you need to identify the noise yourself. At the moment we have no code to identify the noise. Usually the noise class is active at the beginning and end of an utterance.

Do I need to resample speech_image_0_est, speech_image_1_est and noise_image_est? I think I'm missing something.

Resampling is not necessary. Maybe you observed that noise_image_est is actually a speaker; the noise has no fixed index.

pfeatherstone commented 4 years ago

So you can only use this algorithm if you know the noise? If you did a recording of the noise profile, would that work?

boeddeker commented 4 years ago

No, it is not necessary to know the noise. You can use a heuristic to identify the noise, or you can reduce the number of classes to the actual number of speakers. We observed better performance when the noise is modelled as a separate class, but it also works without one.

The heuristic can be a special initialization, an identification of the noise mask, or it can depend on a system before or after the mixture model. It highly depends on the actual application that you have in mind.
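
As one concrete illustration of such a heuristic (an untested sketch; only affiliation_pa comes from the code above, everything else is an assumption), one could pick the class whose permutation-aligned mask is most active in the first and last frames:

import numpy as np

def guess_noise_class(affiliation_pa, edge_frames=50):
    # affiliation_pa: permutation-aligned masks, shape (num_classes, frequencies, frames).
    # Average the mask activity over frequencies, then over the first and last frames;
    # the class with the highest average there is assumed to be the noise class.
    activity = affiliation_pa.mean(axis=1)  # (num_classes, frames)
    edges = np.concatenate(
        [activity[:, :edge_frames], activity[:, -edge_frames:]], axis=-1
    )
    return int(np.argmax(edges.mean(axis=-1)))

# Usage sketch: drop the presumed noise class and keep the speaker masks.
# noise_idx = guess_noise_class(affiliation_pa)
# speaker_masks = np.delete(affiliation_pa, noise_idx, axis=0)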

pfeatherstone commented 4 years ago

So I have this:

import  soundfile
import  numpy as np
from    nara_wpe.utils import stft, istft
from    pb_bss.distribution import CACGMMTrainer
from    pb_bss.permutation_alignment import DHTVPermutationAlignment, OraclePermutationAlignment
from    einops import rearrange

def soundfile_read(file):
    data, data_sample_rate = soundfile.read(file)
    return data_sample_rate, np.ascontiguousarray(data.T)

def bss_pb_bss(multichannel_wav, nsources):
    samplerate, data = soundfile_read(multichannel_wav)
    Observation = stft(data, 512, 128)
    trainer = CACGMMTrainer()
    Observation_mm = rearrange(Observation, 'd t f -> f t d')
    model = trainer.fit(
        Observation_mm,
        num_classes=nsources,
        iterations=40,
        inline_permutation_aligner=None
    )
    affiliation = model.predict(Observation_mm)
    pa = DHTVPermutationAlignment.from_stft_size(512)
    affiliation_pa = pa(rearrange(affiliation, 'f k t -> k f t'))

    # global_pa_est = rearrange(affiliation_pa, 'k f (t d) -> k d t f', d=1) * rearrange(Observation,'d t (f k) -> k d t f', k=1)
    # global_pa_est = rearrange(global_pa_est, 'k d t f -> k (d t f)')
    # global_pa_reference = rearrange(np.array([*Speech_image, Noise_image]), 'k d t f -> k (d t f)')
    # global_pa = OraclePermutationAlignment()
    # global_permutation = global_pa.calculate_mapping(global_pa_est, global_pa_reference)
    # global_permutation
    # affiliation_pa = affiliation_pa[global_permutation]

    extracted = [Observation[0, :, :].T * affiliation_pa[i, :, :] for i in range(nsources)]
    extracted = [istft(extracted[i].T, 512, 128)[..., :Observation.shape[-1]] for i in range(nsources)]
    return extracted

But I don't know what to do with global_pa_est and global_pa_reference if I don't know Speech_image and Noise_image.

boeddeker commented 4 years ago

Since you don't give me any context, it is probably best for you to reduce nsources to the actual number of speakers. This will probably degrade the performance, but for a heuristic to identify the noise you need more knowledge.

Then the extracted signals will be the enhanced signals in an arbitrary permutation.

By the way, to get a better enhanced signal you could use beamforming instead of masking, but I haven't written an example for that yet.
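
To sketch the idea anyway (untested, plain numpy, not the pb_bss API; the mask-to-class assignment is an assumption), mask-based MVDR beamforming in the Souden formulation could look roughly like this:

import numpy as np

def masked_psd(Y, mask):
    # Mask-weighted spatial covariance per frequency.
    # Y:    STFT of the observation, shape (channels, frames, frequencies)
    # mask: mask of one class, shape (frequencies, frames)
    Y = np.einsum('dtf->ftd', Y)
    psd = np.einsum('ft,ftd,fte->fde', mask, Y, Y.conj())
    return psd / np.maximum(mask.sum(axis=-1), 1e-10)[:, None, None]

def mvdr_souden(target_psd, noise_psd, ref_channel=0):
    # One MVDR vector per frequency: w = (Phi_nn^-1 Phi_xx) u / trace(Phi_nn^-1 Phi_xx)
    num = np.linalg.solve(noise_psd, target_psd)
    return num[..., ref_channel] / np.trace(num, axis1=-2, axis2=-1)[:, None]

# Usage sketch (Observation: (d, t, f); target_mask / noise_mask: (f, t)):
# w = mvdr_souden(masked_psd(Observation, target_mask),
#                 masked_psd(Observation, noise_mask))
# enhanced = np.einsum('fd,dtf->tf', w.conj(), Observation)  # (frames, frequencies), then istft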

pfeatherstone commented 4 years ago

The context is: I have a multichannel recording of 3 speakers, and I was hoping this repo could attempt to extract those speakers blindly. So yep, nsources here is the number of speakers, in this case 3. But given the example code, I don't know how to set global_pa_est and global_permutation if they are required.

boeddeker commented 4 years ago

I personally prefer to use nsources=4 in this case. But it looks like you have an application in mind where you don't have the source signals (i.e. no academic task). Using nsources=4 in such a situation requires much more knowledge: for example, a voice activity detector before the source separation, a very good SNR, or a follow-up system that can identify the noise.

Given this sparse information, you could choose nsources=3 and not care about the global permutation. But keep in mind that the follow-up system has to handle an arbitrary speaker permutation. No algorithm can solve the global permutation given just the observation.

pfeatherstone commented 4 years ago

Yep, that's fine. So to recap, for nsources==3, it's ok not to have the noise, and the following code is ok?

def bss_pb_bss(multichannel_wav, nsources):
    samplerate, data = soundfile_read(multichannel_wav)
    Observation = stft(data, 512, 128)
    trainer = CACGMMTrainer()
    Observation_mm = rearrange(Observation, 'd t f -> f t d')
    model = trainer.fit(
        Observation_mm,
        num_classes=nsources,
        iterations=40,
        inline_permutation_aligner=None
    )
    affiliation = model.predict(Observation_mm)
    pa = DHTVPermutationAlignment.from_stft_size(512)
    affiliation_pa = pa(rearrange(affiliation, 'f k t -> k f t'))

    # global_pa_est = rearrange(affiliation_pa, 'k f (t d) -> k d t f', d=1) * rearrange(Observation,'d t (f k) -> k d t f', k=1)
    # global_pa_est = rearrange(global_pa_est, 'k d t f -> k (d t f)')
    # global_pa_reference = rearrange(np.array([*Speech_image, Noise_image]), 'k d t f -> k (d t f)')
    # global_pa = OraclePermutationAlignment()
    # global_permutation = global_pa.calculate_mapping(global_pa_est, global_pa_reference)
    # affiliation_pa = affiliation_pa[global_permutation]

    extracted = [Observation[0, :, :].T * affiliation_pa[i, :, :] for i in range(nsources)]
    extracted = [istft(extracted[i].T, 512, 128)[..., :Observation.shape[-1]] for i in range(nsources)]
    return extracted

However, for me, data has shape (4, 395264) (so 4 input channels and 395264 samples each), and I would expect extracted to be a list of 3 numpy arrays of shape (1, 395264) or (395264,). But using this code, extracted is a list of numpy arrays of shape (257,). I don't understand the shapes. I think some ops are missing.

boeddeker commented 4 years ago

Yep, that's fine. So to recap, for nsources==3, it's ok not to have the noise

Yes. If the SNR is ok and the noise is not a point source.

However, for me, data has shape (4, 395264) (so 4 input channels and 395264 samples each), and I would expect extracted to be a list of 3 numpy arrays of shape (1, 395264) or (395264,). But using this code, extracted is a list of numpy arrays of shape (257,). I don't understand the shapes. I think some ops are missing.

That is the reason why we created the notebook with a visualization of the spectrum. Try to visualize the intermediate signals and take a look at the shapes. 257 is the number of frequencies. The istft needs an input of shape ... frames frequencies. You can take a look at the docstrings of our functions. The istft has a num_samples argument, and when you change istft(extracted[i].T, 512, 128)[..., :Observation.shape[-1]] to istft(extracted[i].T, 512, 128, num_samples=Observation.shape[-1]) it should complain that 257 is too low for an output sample length of 395_XXX.
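
For example, still inside bss_pb_bss after affiliation_pa has been computed, a quick (untested) way to look at the shapes and one of the masks could be:

import matplotlib.pyplot as plt

print(Observation.shape)     # (channels, frames, frequencies), here (4, T, 257)
print(affiliation_pa.shape)  # (classes, frequencies, frames)

# Show one permutation-aligned mask: frequency bins over frames.
plt.imshow(affiliation_pa[0], origin='lower', aspect='auto')
plt.xlabel('frames')
plt.ylabel('frequency bins')
plt.colorbar()
plt.show()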

pfeatherstone commented 4 years ago

The function definition is

def istft(
        stft_signal,
        size=1024,
        shift=256,
        window=signal.blackman,
        fading=True,
        window_length=None,
        symmetric_window=False,
):

I can't see num_samples. Sorry, I'm not trying to be a pain; I just want to see if it's possible to do blind source separation using this repo.

boeddeker commented 4 years ago

Sorry, I forgot that we have two versions online. The version in nara_wpe is slightly older. The function paderbox.transform.module_stft.istft does the same, but recently we added the num_samples argument.

Nevertheless, the problem in your code is that Observation.shape[-1] is 257, while you want to use data.shape[-1].
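
So, sticking with the nara_wpe istft (which has no num_samples argument), the end of your function would become something like this sketch:

    # Mask the first channel of the observation and go back to the time domain.
    extracted = [Observation[0, :, :].T * affiliation_pa[i, :, :] for i in range(nsources)]
    # Trim to the original number of samples (data.shape[-1]), not the number of frequencies.
    extracted = [istft(extracted[i].T, 512, 128)[..., :data.shape[-1]] for i in range(nsources)]
    return extracted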

pfeatherstone commented 4 years ago

Yep, it seems to work fine. Thank you very much @boeddeker. I should have said this a couple of hours ago.