etzinis / two_step_mask_learning

A two step optimization for sound source separation on the adaptive front-end domain
64 stars 16 forks source link

about sdr #3

Open KiAlexander opened 4 years ago

KiAlexander commented 4 years ago

I tried to test my code, which calculates SDR, using your separated samples (ex_18).

With my SDR code, the result is about 6.47, while yours is 19.37.

Can you help me find out whether anything is wrong in my code? Thanks.

The code is as follows.

`#!/usr/bin/env python

import soundfile as sf
from mir_eval.separation import bss_eval_sources
import numpy as np

import torch

from itertools import permutations

def cal_SDRi(src_ref, src_est, mix):
    """Calculate the Source-to-Distortion Ratio improvement (SDRi).

    NOTE: bss_eval_sources is very very slow.

    Args:
        src_ref: numpy.ndarray, [C, T] reference sources.
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation.
        mix: numpy.ndarray, [T] input mixture.
    Returns:
        Mean SDR improvement over using the mixture itself as the estimate.
    """
    n_src = src_ref.shape[0]
    # Baseline: the mixture repeated once per source. Generalizes the
    # original hard-coded 2-source np.stack([mix, mix]).
    src_anchor = np.stack([mix] * n_src, axis=0)
    sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
    sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
    # Mean over sources of (estimate SDR - mixture-baseline SDR); identical
    # to the original pairwise average when C == 2.
    return float(np.mean(sdr - sdr0))

def cal_SISNRi(src_ref, src_est, mix):
    """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi).

    Args:
        src_ref: numpy.ndarray, [C, T] reference sources.
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation.
        mix: numpy.ndarray, [T] input mixture.
    Returns:
        Mean SI-SNR improvement over using the mixture as the estimate.
    """
    # Generalized from the original hard-coded 2-source version: compare
    # each estimate against its reference, minus the mixture baseline.
    n_src = src_ref.shape[0]
    improvements = []
    for c in range(n_src):
        sisnr_est = cal_SISNR(src_ref[c], src_est[c])
        sisnr_mix = cal_SISNR(src_ref[c], mix)
        improvements.append(sisnr_est - sisnr_mix)
    return sum(improvements) / n_src

def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    """Compute the Scale-Invariant Source-to-Noise Ratio (SI-SNR).

    Args:
        ref_sig: numpy.ndarray, [T] reference (target) signal.
        out_sig: numpy.ndarray, [T] estimated signal.
        eps: small constant guarding divisions and the log argument.
    Returns:
        SI-SNR in dB.
    """
    assert len(ref_sig) == len(out_sig)
    # Remove DC offset from both signals before projecting.
    ref = ref_sig - np.mean(ref_sig)
    est = out_sig - np.mean(out_sig)
    # Project the estimate onto the reference direction.
    ref_energy = np.sum(ref ** 2) + eps
    target = (np.dot(ref, est) / ref_energy) * ref
    residual = est - target
    # dB ratio of projection power to residual power.
    power_ratio = np.sum(target ** 2) / (np.sum(residual ** 2) + eps)
    return 10 * np.log10(power_ratio + eps)

def calc_sdr(estimation, origin):
    """Batch-wise scale-invariant SDR calculation for one audio file.

    Args:
        estimation: numpy.ndarray, (batch, nsample) estimated signals.
        origin: numpy.ndarray, (batch, nsample) reference signals.
    Returns:
        numpy.ndarray, (batch,) SDR in dB for each batch item.
    """
    # Reference energy, kept 2-D so the division broadcasts per row.
    ref_energy = np.sum(origin * origin, 1, keepdims=True) + 1e-8  # (batch, 1)
    # Optimal scaling of the reference toward the estimate.
    alpha = np.sum(origin * estimation, 1, keepdims=True) / ref_energy  # (batch, 1)

    target = alpha * origin           # (batch, nsample) projected target
    distortion = estimation - target  # (batch, nsample) residual

    # 10*log10(target power) - 10*log10(residual power), per batch item.
    return 10 * np.log10(np.sum(target ** 2, 1)) - 10 * np.log10(np.sum(distortion ** 2, 1))

def compute_measures(se, s, j):
    """BSS-eval style SDR/SIR/SAR for one estimated source.

    Args:
        se: numpy.ndarray, [T] estimated source signal.
        s: numpy.ndarray, [T, nsrc] all reference sources as columns.
        j: column index of the reference that `se` is matched against.
    Returns:
        (SDR, SIR, SAR) in dB.
    """
    gram = s.transpose().dot(s)  # Gram matrix of the references
    ref = s[:, j]

    # Project the estimate onto the target reference.
    gain = ref.transpose().dot(se) / gram[j, j]
    e_true = gain * ref
    e_res = se - e_true
    target_pow = np.sum(e_true ** 2)
    res_pow = np.sum(e_res ** 2)

    SDR = 10 * np.log10(target_pow / res_pow)

    # Least-squares split of the residual: the part inside the span of all
    # references is interference, the remainder is artifact.
    coeffs = np.linalg.inv(gram).dot(s.transpose().dot(e_res))
    e_interf = s.dot(coeffs)
    e_artif = e_res - e_interf

    SIR = 10 * np.log10(target_pow / np.sum(e_interf ** 2))
    SAR = 10 * np.log10(target_pow / np.sum(e_artif ** 2))
    return SDR, SIR, SAR

def GetSDR(se,s):
    """Permutation-resolved SDR/SIR/SAR in the style of BSS-eval.

    Args:
        se: numpy.ndarray, [nsrc, T] estimated sources (one per row).
        s: numpy.ndarray, [nsrc, T] reference sources (one per row).
    Returns:
        SDR, SIR, SAR: numpy arrays of length nsrc; entry k scores estimate
            per[k] against reference k.
        per: list, the assignment of estimates to references that maximizes
            the mean SIR.
    """
    # Switch to [T, nsrc] column layout.
    se = se.transpose()
    s = s.transpose()

    # Remove DC offset from every column.
    se=se-np.mean(se,axis=0)
    s=s-np.mean(s,axis=0)
    nsampl,nsrc=se.shape
    nsampl2,nsrc2=s.shape

    assert(nsrc2==nsrc)
    assert(nsampl2==nsampl)

    # Metric matrices: entry [jest, jtrue] scores estimate jest against
    # reference jtrue.
    SDR=np.zeros((nsrc,nsrc))
    SIR=SDR.copy()
    SAR=SDR.copy()

    for jest in range(nsrc):
        for jtrue in range(nsrc):
            SDR[jest,jtrue],SIR[jest,jtrue],SAR[jest,jtrue]=compute_measures(se[:,jest],s,jtrue)

    # Score every assignment of estimates to references; meanSIR[p] is the
    # average SIR when reference k is paired with estimate perm[p][k].
    perm=list(permutations(np.arange(nsrc)))
    nperm=len(perm)
    meanSIR=np.zeros((nperm,))
    for p in range(nperm):
        # After transpose+reshape, tp[jtrue*nsrc + jest] == SIR[jest, jtrue].
        tp=SIR.transpose().reshape(nsrc*nsrc)
        idx=np.arange(nsrc)*nsrc+list(perm[p])
        meanSIR[p]=np.mean(tp[idx])
    # Keep the permutation with the best mean SIR and reorder all three
    # metrics by it, so entry k pairs estimate per[k] with reference k.
    popt=np.argmax(meanSIR)
    per=list(perm[popt])
    idx=np.arange(nsrc)*nsrc+per
    SDR=SDR.transpose().reshape(nsrc*nsrc)[idx]
    SIR=SIR.transpose().reshape(nsrc*nsrc)[idx]
    SAR=SAR.transpose().reshape(nsrc*nsrc)[idx]
    return SDR, SIR, SAR, per

EPS = 1e-8  # numerical floor guarding divisions and log10 arguments below

def cal_si_snr_with_pit(source, estimate_source, source_lengths):
    """Calculate SI-SNR with permutation-invariant training (PIT).

    Fix: removed a leftover debug print of the pairwise SI-SNR matrix and a
    dead commented-out torch.gather line.

    Args:
        source: [B, C, T] tensor, B is batch size, C sources, T samples.
        estimate_source: [B, C, T] tensor of separated estimates.
        source_lengths: [B] tensor, each item is between [0, T].
    Returns:
        max_snr: [B, 1] best SI-SNR per utterance, averaged over C sources.
        perms: [C!, C] tensor of all source permutations.
        max_snr_idx: [B] index into `perms` of the best permutation.
    """
    assert source.size() == estimate_source.size()
    B, C, T = source.size()

    # Step 1. Zero-mean norm over the valid sample count of each utterance.
    num_samples = source_lengths.view(-1, 1, 1).float()  # [B, 1, 1]
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate

    # Step 2. Pairwise SI-SNR between every (estimate, target) pair, using
    # broadcasting over the inserted singleton dimensions.
    s_target = torch.unsqueeze(zero_mean_target, dim=1)  # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # s_target projection = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)  # [B, C, C, 1]
    s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS)  # [B, C, C]

    # Step 3. Pick the permutation with the highest summed SI-SNR.
    # permutations, [C!, C]
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    # one-hot, [C!, C, C]
    index = torch.unsqueeze(perms, 2)
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
    # [B, C!] <- [B, C, C] einsum [C!, C, C]: SI-SNR sum of each permutation.
    snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)  # [B, 1]
    max_snr /= C  # average over sources
    return max_snr, perms, max_snr_idx

def _sdr( y, z, SI=False):
    if SI:
        a = ((z*y).mean(-1) / (y*y).mean(-1)).unsqueeze(-1) * y
        return 10*torch.log10( (a**2).mean(-1) / ((a-z)**2).mean(-1))
    else:
        return 10*torch.log10( (y*y).mean(-1) / ((y-z)**2).mean(-1))

def test():
    """Compare all the SDR/SI-SNR implementations above on the ex_18 sample."""
    mix = sf.read('./ex_18/mixture.wav')[0]
    source = np.stack([sf.read('./ex_18/s1.wav')[0],
                       sf.read('./ex_18/s2.wav')[0]], axis=0)
    estimate_source = np.stack([sf.read('./ex_18/s1_estimate.wav')[0],
                                sf.read('./ex_18/s2_estimate.wav')[0]], axis=0)

    # mir_eval-based improvements over the mixture baseline.
    SDRi = cal_SDRi(source, estimate_source, mix)
    SISNRi = cal_SISNRi(source, estimate_source, mix)
    print('SDRi:{}'.format(SDRi))
    print('SISNRi:{}\n'.format(SISNRi))

    # Batch scale-invariant SDR, against the mixture as anchor.
    anchor = np.stack([mix, mix], axis=0)
    sdr1 = calc_sdr(source, estimate_source)
    sdr2 = calc_sdr(source, anchor)
    print('sdr1:{}'.format(sdr1))
    print('sdr2:{}'.format(sdr2))
    print('sdri:{}\n'.format(np.mean(sdr1 - sdr2)))

    # Permutation-resolved BSS-eval metrics.
    SDR, SIR, SAR, per = GetSDR(estimate_source, source)
    print('SDR:{}\nSIR:{}\nSAR:{}\nper:{}\n'.format(SDR, SIR, SAR, per))

    # Torch PIT SI-SNR on the same data.
    src_t = torch.from_numpy(np.array([source])).float()
    est_t = torch.from_numpy(np.array([estimate_source])).float()
    source_lengths = torch.from_numpy(np.array([mix.shape]))
    max_snr, _, _ = cal_si_snr_with_pit(src_t, est_t, source_lengths)
    print('max_snr:{}\n'.format(max_snr))

    print('SISDR: ', _sdr(src_t, est_t, SI=True))

if __name__ == '__main__':
    # Run the metric comparison when executed as a script.
    test()

And the output:

    #     SDRi:7.892910056532607
    #     SISNRi:7.151290758024819

    #     sdr1:[6.68677879 6.25589817]
    #     sdr2:[-1.18234941 -0.17591568]
    #     sdri:7.150471030776565

    #     SDR:[6.68712455 6.25720926]
    #     SIR:[34.92616321 34.92229867]
    #     SAR:[6.69364393 6.26311903]
    #     per:[0, 1]

    #     sisnr: tensor([[[  6.6871, -34.0962],
    #              [-34.1721,   6.2572]]])
    #     max_snr:  tensor([[6.4722]])

    #     SISDR:  tensor([[6.6868, 6.2559]])

    # ex_18/metrics.json
    # {
    # "input_si_sdr": 0.028149127960205078,
    # "input_sdr": 0.15109104033014964,
    # "input_sir": 0.1510910403301708,
    # "input_sar": 144.89122580687916,
    # "input_stoi": 0.7178163832006375,
    # "input_pesq": 1.599277138710022,
    # "si_sdr": 19.083293914794922,
    # "sdr": 19.376235432506704,
    # "sir": 30.187015165321924,
    # "sar": 19.759935974444744,
    # "stoi": 0.9568062920227058,
    # "pesq": 3.562618613243103,
    # "mix_path": "/mnt/data/wham/wav8k/min/tt/mix_clean/050a050c_0.050237_442c020j_-0.050237.wav"
    # }
KiAlexander commented 4 years ago

And here is the result calculated by pb_bss_eval:

{'input_pesq': 1.5960057377815247,
 'input_sar': 11.243911897495014,
 'input_sdr': -0.5471245564855747,
 'input_si_sdr': -0.7163815595640699,
 'input_sir': 0.14099714373415928,
 'input_stoi': 0.662681954751808,
 'pesq': 2.596807837486267,
 'sar': 6.761317115018516,
 'sdr': 6.659954133819946,
 'si_sdr': 5.684012353271584,
 'sir': 23.99486035934128,
 'stoi': 0.8663816779000003}
etzinis commented 4 years ago

Hey thanks for reaching out. I am kind of trying to catch a deadline so I did not check my github issues. It seems that your code is fine. Surprisingly there is a bug probably with the uploaded sources. If you listen to both the estimates and the actual sources you could actually hear that the sources and the mixture sound very noisy. However, the estimation seems to be quite better quality with much less artifacts. Moreover, I have actually used this code to produce some audio examples https://github.com/mpariente/asteroid/tree/master/egs/wham/TwoStep which might also contain this noise because of the wham dataset. I am gonna take a look at this, hopefully in a few days.