sigsep / bsseval

audio source separation evaluation metrics
MIT License

Add SI-SDR #3

Open faroit opened 5 years ago

faroit commented 5 years ago

Add the scale-invariant SDR (SI-SDR) metric. See https://arxiv.org/abs/1811.02508
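For reference, this is the definition from that paper, written in our own notation (assumed here: $\hat{s}$ is the estimate, $s$ the target reference, and $\alpha$ the optimal scaling factor applied to the reference):

```latex
\alpha = \frac{\hat{s}^{\top} s}{\lVert s \rVert^{2}}, \qquad
\text{SI-SDR} = 10 \log_{10} \frac{\lVert \alpha s \rVert^{2}}{\lVert \alpha s - \hat{s} \rVert^{2}}
```

Fixing $\alpha = 1$ instead of using the optimal scale reduces this to a plain SNR.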

faroit commented 5 years ago

@Jonathan-LeRoux would you be able to contribute Python code, or could you actively review our code if we need to work it out from the paper ourselves?

Jonathan-LeRoux commented 5 years ago

I'm happy to help in the most efficient way. I could try to insert our code into the framework, but I'm thinking that, given the simplicity of SI-SDR, giving you our code and discussing how to insert it may be faster and also less risky in terms of introducing bugs. SI-SDR basically involves slightly modified (and much simpler) versions of _bss_decomp_mtifilt and _bss_crit. The main function is this one, in which the decomposition and the criterion computation are merged:

import math

import numpy as np


def compute_measures(estimated_signal, reference_signals, j, scaling=True):
    """Return (SI-SDR, SI-SIR, SI-SAR) in dB for the estimate vs. reference j.

    reference_signals has shape (num_samples, num_sources).
    """
    Rss = np.dot(reference_signals.transpose(), reference_signals)
    this_s = reference_signals[:, j]

    if scaling:
        # get the scaling factor for clean sources
        a = np.dot(this_s, estimated_signal) / Rss[j, j]
    else:
        a = 1

    e_true = a * this_s
    e_res = estimated_signal - e_true

    Sss = (e_true**2).sum()
    Snn = (e_res**2).sum()

    SDR = 10 * math.log10(Sss / Snn)

    # Get the SIR: least-squares projection of the residual onto all references
    Rsr = np.dot(reference_signals.transpose(), e_res)
    b = np.linalg.solve(Rss, Rsr)

    e_interf = np.dot(reference_signals, b)
    e_artif = e_res - e_interf

    SIR = 10 * math.log10(Sss / (e_interf**2).sum())
    SAR = 10 * math.log10(Sss / (e_artif**2).sum())

    return SDR, SIR, SAR

I'm hoping that this could be easily inserted within a function that figures out the best permutation.
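A minimal sketch of such a permutation search, assuming a single-reference SI-SDR (only the SDR term of compute_measures, ignoring the other references) and estimates/references stored as (num_samples, num_sources) arrays. The helper names si_sdr and best_permutation_si_sdr are hypothetical, not part of bsseval:

```python
import itertools

import numpy as np


def si_sdr(estimate, reference):
    # Single-reference SI-SDR in dB: the SDR term of compute_measures,
    # without the other references (an assumption to keep the sketch short).
    alpha = np.dot(reference, estimate) / np.dot(reference, reference)
    target = alpha * reference
    residual = estimate - target
    return 10 * np.log10(np.sum(target ** 2) / np.sum(residual ** 2))


def best_permutation_si_sdr(estimates, references):
    # Hypothetical helper: brute-force search over all assignments of
    # estimates to references, keeping the one with the highest mean SI-SDR.
    # estimates, references: arrays of shape (num_samples, num_sources).
    num_sources = references.shape[1]
    best_perm, best_scores = None, None
    for perm in itertools.permutations(range(num_sources)):
        # Reference j is paired with estimate perm[j].
        scores = [si_sdr(estimates[:, perm[j]], references[:, j])
                  for j in range(num_sources)]
        if best_scores is None or np.mean(scores) > np.mean(best_scores):
            best_perm, best_scores = perm, scores
    return best_perm, best_scores
```

Brute force visits num_sources! orderings, which is fine for a few sources. Because each score here depends only on one (estimate, reference) pair, an assignment solver such as scipy.optimize.linear_sum_assignment could replace the loop for larger problems.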

faroit commented 5 years ago

Thanks. That's impressively simple. I'll add this as soon as the refactoring is done.

pseeth commented 5 years ago

Happy to give this a shot once the refactor is done! Putting together a PyTorch version should be easy as well. How does this work for multichannel signals? Should each channel be evaluated separately and the scores averaged, or should the numbers be reported per channel? Also, I don't remember an equivalent of ISR for this measure, so the return spec might differ between bsseval and SI-SDR.
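One possible answer, sketched under assumptions: compute SI-SDR per channel and then either report or average the scores, or flatten all channels into one vector so that a single scale factor is shared. The function multichannel_si_sdr and its mode argument are hypothetical; which convention bsseval should adopt is exactly the open question here.

```python
import numpy as np


def si_sdr(estimate, reference):
    # Single-channel, single-reference SI-SDR in dB.
    alpha = np.dot(reference, estimate) / np.dot(reference, reference)
    target = alpha * reference
    residual = estimate - target
    return 10 * np.log10(np.sum(target ** 2) / np.sum(residual ** 2))


def multichannel_si_sdr(estimate, reference, mode="mean"):
    # estimate, reference: arrays of shape (num_samples, num_channels).
    # mode="none":    one SI-SDR per channel (a separate scale per channel)
    # mode="mean":    average of the per-channel scores
    # mode="flatten": channels concatenated, so one scale is shared by all channels
    if mode == "flatten":
        return si_sdr(estimate.ravel(), reference.ravel())
    per_channel = np.array([si_sdr(estimate[:, c], reference[:, c])
                            for c in range(reference.shape[1])])
    if mode == "mean":
        return float(per_channel.mean())
    if mode == "none":
        return per_channel
    raise ValueError(f"unknown mode: {mode}")
```

The conventions differ when channels have different gains: per-channel scoring forgives an inter-channel level mismatch, while the flattened version penalizes it.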

pseeth commented 5 years ago

I refactored the SI-SDR code for nussl. Once bsseval is mature, I'll strip it from nussl and have nussl depend on bsseval instead. Putting my code here in case it's helpful!

from . import EvaluationBase
from itertools import permutations
import numpy as np

class ScaleInvariantSDR(EvaluationBase):
    def __init__(self, true_sources_list, estimated_sources_list, 
                 compute_permutation=False, source_labels=None, scaling=True):
        self.true_sources_list = true_sources_list
        self.estimated_sources_list = estimated_sources_list
        self.compute_permutation = compute_permutation
        self.scaling = scaling

        if source_labels is None:
            source_labels = []
            for i, x in enumerate(self.true_sources_list):
                if x.path_to_input_file:
                    label = x.path_to_input_file
                else:
                    label = f'source_{i}'
                source_labels.append(label)
        self.source_labels = source_labels
        self.reference_array, self.estimated_array = self._preprocess_sources()

    def evaluate(self):
        num_sources = self.reference_array.shape[-1]
        num_channels = self.reference_array.shape[1]
        orderings = (
            list(permutations(range(num_sources))) 
            if self.compute_permutation 
            else [list(range(num_sources))]
        )
        results = np.empty((len(orderings), num_channels, num_sources, 3))

        for o, order in enumerate(orderings):
            for c in range(num_channels):
                for j in order:
                    SDR, SIR, SAR = self._compute_sdr(
                        self.estimated_array[:, c, j], self.reference_array[:, c, order], j, scaling=self.scaling
                    )
                    results[o, c, j, :] = [SDR, SIR, SAR]
        return self._populate_scores_dict(results, orderings)

    def _populate_scores_dict(self, results, orderings):
        best_permutation_by_sdr = np.argmax(results[:, :, :, 0].mean(axis=1).mean(axis=-1))
        results = results[best_permutation_by_sdr]
        best_permutation = orderings[best_permutation_by_sdr]
        scores = {'permutation': list(best_permutation)}
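        for j, ref_idx in enumerate(best_permutation):
            # estimate j was scored against reference best_permutation[j],
            # so label the scores with that reference's name
            label = self.source_labels[ref_idx]
            scores[label] = {
                metric: results[:, j, m].tolist()
                for m, metric in enumerate(['SDR', 'SIR', 'SAR'])
            }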
        return scores

    @staticmethod
    def _compute_sdr(estimated_signal, reference_signals, source_idx, scaling=True):
        references_projection = reference_signals.T @ reference_signals
        source = reference_signals[:, source_idx]
        scale = (source @ estimated_signal) / references_projection[source_idx, source_idx] if scaling else 1

        e_true = scale * source
        e_res = estimated_signal - e_true

        signal = (e_true ** 2).sum()
        noise = (e_res ** 2).sum()
        SDR = 10 * np.log10(signal / noise)

        references_onto_residual = reference_signals.T @ e_res
        b = np.linalg.solve(references_projection, references_onto_residual)

        e_interf = reference_signals @ b
        e_artif = e_res - e_interf

        SIR = 10 * np.log10(signal / (e_interf**2).sum())
        SAR = 10 * np.log10(signal / (e_artif**2).sum())
        return SDR, SIR, SAR

    def _preprocess_sources(self):
        """
        Prepare the :ref:`audio_data` in the sources. Uses the format:
            (num_samples, num_channels, num_sources)
        Returns:
            (:obj:`np.ndarray`, :obj:`np.ndarray`) reference_source_array, estimated_source_array

        """
        reference_source_array = np.stack([np.copy(x.audio_data.T)
                                            for x in self.true_sources_list], axis=2)
        estimated_source_array = np.stack([np.copy(x.audio_data.T)
                                            for x in self.estimated_sources_list], axis=2)
        reference_source_array -= reference_source_array.mean(axis=0)
        estimated_source_array -= estimated_source_array.mean(axis=0)

        return reference_source_array, estimated_source_array

mpariente commented 4 years ago

Sorry to barge in on this issue, but is there any news regarding SI-SDR? I'll need to rewrite my evaluation scripts soon, and between speechmetrics, bsseval, mir_eval and others, the choice is quite hard ;-)

j-paulus commented 4 years ago

This may be more of an algorithmic question for Jonathan about the computation than about BSSEval proper, but since the reference implementation is shown here, I'll take the opportunity:

As I understand the algorithm, the full SI-BSSEval computation first applies a greedy target scaling (for SI-SDR) and then an MMSE assignment of the residual to all references (for SI-SIR and SI-SAR). If the reference sources are uncorrelated, the j-th entry of the mixing array b should be 0: the maximal contribution of the j-th reference was already used in the SI-SDR computation, so the residual is orthogonal to it. If the reference sources are not fully uncorrelated, the j-th element of b will be non-zero, and part of the j-th source may be assigned to the interference. If b were computed without first greedily subtracting the target component, the result would be identical to the scaling-only BSSEval projection (v2.1).

Assuming the above is correct:

No real question here; I'm just trying to check whether my understanding of the algorithm is correct.

Jonathan-LeRoux commented 4 years ago

I think your understanding is correct. SI-SDR implicitly assumes that the references are orthogonal to each other (i.e., uncorrelated). In most cases, the off-diagonal terms of Rss stemming from the correlation between references will be negligible compared to the norms of the references, so the results will be almost the same.
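A small numerical check of this point, assuming the compute_measures structure posted earlier in this thread (interference_coefficients is a hypothetical excerpt that returns only b). With exactly orthogonal references, b[j] vanishes; with strongly correlated references, part of the target reference ends up in the interference term:

```python
import numpy as np


def interference_coefficients(estimate, references, j):
    # Mirrors the decomposition in compute_measures: scale reference j onto
    # the estimate (SI-SDR step), then solve for the least-squares
    # projection b of the residual onto all references (SI-SIR step).
    Rss = references.T @ references
    s_j = references[:, j]
    a = (s_j @ estimate) / Rss[j, j]
    e_res = estimate - a * s_j
    return np.linalg.solve(Rss, references.T @ e_res)
```

The residual is always orthogonal to reference j, so b[j] can only become non-zero through the off-diagonal entries of Rss.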

jvel07 commented 1 year ago

Thanks for the code, @Jonathan-LeRoux! Since the function returns SDR, SIR, and SAR, how do you actually compute SI-SDR or SI-SDRi from there? Also, what is the role of j?

Jonathan-LeRoux commented 1 year ago

The code says SDR, SIR, SAR, but these are really SI-SDR, SI-SIR, SI-SAR. SI-SDRi can be computed as the difference between the SI-SDR obtained with the estimate and the SI-SDR obtained when the original mixture is used as the estimate.
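A minimal sketch of that recipe, assuming single-reference SI-SDR (the SDR term of compute_measures with scaling=True) and 1-D NumPy arrays; si_sdr_improvement is a hypothetical helper name:

```python
import numpy as np


def si_sdr(estimate, reference):
    # Single-reference SI-SDR in dB (the SDR term of compute_measures with scaling=True).
    alpha = np.dot(reference, estimate) / np.dot(reference, reference)
    target = alpha * reference
    residual = estimate - target
    return 10 * np.log10(np.sum(target ** 2) / np.sum(residual ** 2))


def si_sdr_improvement(estimate, mixture, reference):
    # SI-SDRi: SI-SDR of the estimate minus SI-SDR of the unprocessed mixture,
    # both measured against the same reference source.
    return si_sdr(estimate, reference) - si_sdr(mixture, reference)
```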

jvel07 commented 1 year ago

I see, thanks for the reply @Jonathan-LeRoux :) What are the roles of the arguments j and scaling? Does scaling really impact the metrics?

Jonathan-LeRoux commented 1 year ago

j determines which reference source to compare with, and scaling switches between SNR (scaling=False, so a = 1) and SI-SDR (scaling=True).
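To illustrate, here is a sketch of just the SDR term of compute_measures for a single reference (sdr_term is a hypothetical name). With scaling=True, multiplying the estimate by a constant leaves the score unchanged; with scaling=False (SNR), the same gain change is penalized:

```python
import numpy as np


def sdr_term(estimate, reference, scaling=True):
    # The SDR part of compute_measures for one reference:
    # scaling=True -> SI-SDR (optimal scale a), scaling=False -> plain SNR (a = 1).
    if scaling:
        a = np.dot(reference, estimate) / np.dot(reference, reference)
    else:
        a = 1.0
    e_true = a * reference
    e_res = estimate - e_true
    return 10 * np.log10(np.sum(e_true ** 2) / np.sum(e_res ** 2))
```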