fakufaku / fast_bss_eval

A fast implementation of bss_eval metrics for blind source separation
https://fast-bss-eval.readthedocs.io/en/latest/
MIT License

ValueError: einstein sum subscripts string contains too many subscripts for operand 0 #11

Closed Shin-ichi-Takayama closed 2 years ago

Shin-ichi-Takayama commented 2 years ago

Hello. I ran the following Python code with the sample code as a reference.

from scipy.io import wavfile  
import fast_bss_eval  

fs, ref = wavfile.read("./test/ref.wav")  
_,  est = wavfile.read("./test/est.wav")  

sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref.T, est.T)  

However, the following error occurred.

(fast_bss_eval) C:\Users\4020737\Documents\git\FastBssEval>python eval.py
C:\Users\4020737\Documents\git\FastBssEval\eval.py:4: WavFileWarning: Chunk (non-data) not understood, skipping it.
  fs, ref = wavfile.read("./test/ref.wav")
C:\Users\4020737\Documents\git\FastBssEval\eval.py:5: WavFileWarning: Chunk (non-data) not understood, skipping it.
  _,  est = wavfile.read("./test/est.wav")
Traceback (most recent call last):
  File "C:\Users\4020737\Documents\git\FastBssEval\eval.py", line 8, in <module>
    sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref.T, est.T)
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\__init__.py", line 365, in bss_eval_sources
    return _dispatch_backend(
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\__init__.py", line 304, in _dispatch_backend
    return f_numpy(*args, **kwargs)
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\metrics.py", line 657, in bss_eval_sources
    coh_sdr, coh_sar = square_cosine_metrics(
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\metrics.py", line 522, in square_cosine_metrics
    acf, xcorr = compute_stats_2(ref, est, length=filter_length)
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\metrics.py", line 173, in compute_stats_2
    prod = np.einsum("...cn,...dn->...ncd", X, X.conj())
  File "<__array_function__ internals>", line 180, in einsum
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\numpy\core\einsumfunc.py", line 1359, in einsum
    return c_einsum(*operands, **kwargs)
ValueError: einstein sum subscripts string contains too many subscripts for operand 0

I thought the wav file might be the problem and modified the code as follows, but the result was the same.

from scipy.io import wavfile
import numpy as np
import fast_bss_eval

ref = np.random.randint(1000, 10000, 160000)
est = np.random.randint(1000, 10000, 160000)

#compute the metrics
sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref.T, est.T)

The list of libraries in my environment is as follows.

# packages in environment at C:\Users\4020737\Anaconda3\envs\fast_bss_eval:
#
# Name                    Version                   Build  Channel
blas                      1.0                         mkl
bzip2                     1.0.8                he774522_0
ca-certificates           2022.4.26            haa95532_0
certifi                   2022.5.18.1     py310haa95532_0
fast-bss-eval             0.1.4                      py_0    wietsedv
icc_rt                    2019.0.0             h0cc432a_1
intel-openmp              2021.4.0          haa95532_3556
libffi                    3.4.2                hd77b12b_4
mkl                       2021.4.0           haa95532_640
mkl-service               2.4.0           py310h2bbff1b_0
mkl_fft                   1.3.1           py310ha0764ea_0
mkl_random                1.2.2           py310h4ed8f06_0
numpy                     1.22.3          py310h6d2d95c_0
numpy-base                1.22.3          py310h206c741_0
openssl                   1.1.1o               h2bbff1b_0
pip                       21.2.4          py310haa95532_0
python                    3.10.4               hbb2ffb3_0
scipy                     1.7.3           py310h6d2d95c_0
setuptools                61.2.0          py310haa95532_0
six                       1.16.0             pyhd3eb1b0_1
sqlite                    3.38.3               h2bbff1b_0
tk                        8.6.12               h2bbff1b_0
tzdata                    2022a                hda174b7_0
vc                        14.2                 h21ff451_1
vs2015_runtime            14.27.29016          h5e58377_2
wheel                     0.37.1             pyhd3eb1b0_0
wincertstore              0.2             py310haa95532_2
xz                        5.2.5                h8cc25b3_1
zlib                      1.2.12               h8cc25b3_2

Best regards.

fakufaku commented 2 years ago

Are the wav files single channel? The library requires a channel dimension to be present. Can you try modifying your code as follows?

ref = ref[None, ...]
est = est[None, ...]
sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref, est)  
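As a minimal sketch of why the missing channel dimension triggers this error: the traceback shows the library calling `np.einsum("...cn,...dn->...ncd", X, X.conj())`, and the subscripts `cn` need at least two axes (channels, samples). The array below is synthetic and only stands in for a mono wav file; it is not taken from the issue.

```python
import numpy as np

# A mono signal as returned by wavfile.read for a single-channel file: shape (n_samples,)
x = np.random.randn(16000)

# The subscripts "cn" require two axes, so a 1-D operand is rejected.
try:
    np.einsum("...cn,...dn->...ncd", x, x.conj())
    failed = False
except ValueError:
    failed = True

# Adding a leading channel axis gives shape (1, n_samples), which einsum accepts.
x2 = x[None, ...]
prod = np.einsum("...cn,...dn->...ncd", x2, x2.conj())

print(failed, x2.shape, prod.shape)
```

Note that calling `.T` on a 1-D array leaves it unchanged, so `ref.T` does not add the missing axis.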

Shin-ichi-Takayama commented 2 years ago

Thank you for your response.

Are the wav files single channel ?

Yes, it is a single channel wav file.

I ran the following code.

from scipy.io import wavfile
import numpy as np
import fast_bss_eval

fs, ref = wavfile.read("./test/ref.wav")
_,  est = wavfile.read("./test/est.wav")

ref = ref[None, ...]
est = est[None, ...]

sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref, est)

However, I got a new error.

(fast_bss_eval) C:\Users\4020737\Documents\git\FastBssEval>python eval.py
C:\Users\4020737\Documents\git\FastBssEval\eval.py:5: WavFileWarning: Chunk (non-data) not understood, skipping it.
  fs, ref = wavfile.read("./test/ref.wav")
C:\Users\4020737\Documents\git\FastBssEval\eval.py:6: WavFileWarning: Chunk (non-data) not understood, skipping it.
  _,  est = wavfile.read("./test/est.wav")
C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\helpers.py:69: RuntimeWarning: divide by zero encountered in log10
  return 10.0 * np.log10(ratio)
Traceback (most recent call last):
  File "C:\Users\4020737\Documents\git\FastBssEval\eval.py", line 11, in <module>
    sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref, est)
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\__init__.py", line 365, in bss_eval_sources
    return _dispatch_backend(
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\__init__.py", line 304, in _dispatch_backend
    return f_numpy(*args, **kwargs)
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\metrics.py", line 672, in bss_eval_sources
    neg_sir, neg_sdr, neg_sar, perm = _solve_permutation(
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\helpers.py", line 100, in _solve_permutation
    dum, p_opt = _linear_sum_assignment_with_inf(loss_mat[m])
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\helpers.py", line 136, in _linear_sum_assignment_with_inf
    m = values.min()
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\numpy\core\_methods.py", line 44, in _amin
    return umr_minimum(a, axis, None, out, keepdims, initial, where)
ValueError: zero-size array to reduction operation minimum which has no identity

As a test, I ran the following code.

from scipy.io import wavfile
import numpy as np
import fast_bss_eval

ref = np.random.randint(1000, 10000, 16000)
est = np.random.randint(1000, 10000, 16000)

ref = ref[None, ...]
est = est[None, ...]

sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref, est)

print('sdr:', sdr)
print('sir:', sir)
print('sar:', sar)
print('perm:', perm)

The results were sometimes displayed without errors and sometimes not: the error occurred in 5 out of 10 runs.

fakufaku commented 2 years ago

Thanks for reporting this too! So the problem here is that in the case of only one channel, the SIR is infinite and the permutation solver was not robust enough. Since you have only one source, you can specify compute_permutation=False, which should fix the issue for now. I will also add the necessary checks to avoid such problems in the future.

fakufaku commented 2 years ago

Also note that if you work with PyTorch, providing integer-valued signals will also result in an error. It may be preferable to work with floating-point values. The numpy version seems fine, though.
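One possible way to prepare integer PCM data before evaluation is sketched below. The helper name `to_float_channels` is hypothetical and not part of fast_bss_eval; it assumes signed integer samples (for example int16) and a mono signal. A multi-channel array from `wavfile.read` has shape (n_samples, n_channels) and would need transposing instead.

```python
import numpy as np

def to_float_channels(x):
    """Hypothetical helper: convert PCM samples to float64 with shape (1, n_samples)."""
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.integer):
        # Scale signed integers to roughly [-1, 1), e.g. int16 is divided by 32768.
        scale = float(np.iinfo(x.dtype).max) + 1.0
        x = x.astype(np.float64) / scale
    else:
        x = x.astype(np.float64)
    if x.ndim == 1:
        # Add the channel axis that the metrics expect.
        x = x[None, :]
    return x

pcm = np.array([0, 16384, -32768], dtype=np.int16)
out = to_float_channels(pcm)
print(out.shape, out.dtype)
```

The scaling itself is a convention rather than a requirement; the point is only that the signals are floating point before they are passed on.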

Shin-ichi-Takayama commented 2 years ago

Thanks for your reply, I appreciate it.

I have run the following code.

from scipy.io import wavfile
import numpy as np
import fast_bss_eval

fs, ref = wavfile.read("./test/ref.wav")
_,  est = wavfile.read("./test/est.wav")

ref = ref[None, ...]
est = est[None, ...]

sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref, est, compute_permutation=False)

print('sdr:', sdr)
print('sir:', sir)
print('sar:', sar)
print('perm:', perm)

Then I got the following error message.

(fast_bss_eval) C:\Users\4020737\Documents\git\FastBssEval>python eval.py
C:\Users\4020737\Documents\git\FastBssEval\eval.py:5: WavFileWarning: Chunk (non-data) not understood, skipping it.
  fs, ref = wavfile.read("./test/ref.wav")
C:\Users\4020737\Documents\git\FastBssEval\eval.py:6: WavFileWarning: Chunk (non-data) not understood, skipping it.
  _,  est = wavfile.read("./test/est.wav")
C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\helpers.py:69: RuntimeWarning: divide by zero encountered in log10
  return 10.0 * np.log10(ratio)
Traceback (most recent call last):
  File "C:\Users\4020737\Documents\git\FastBssEval\eval.py", line 11, in <module>
    sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref, est, compute_permutation=False)
ValueError: not enough values to unpack (expected 4, got 3)

I would like to use this great tool and would appreciate your support. Best regards.

fakufaku commented 2 years ago

If you have compute_permutation=False, then perm is not returned.

sdr, sir, sar = fast_bss_eval.bss_eval_sources(ref, est, compute_permutation=False)

should work.

Shin-ichi-Takayama commented 2 years ago

Thank you for your response. I was able to load a single channel audio file and confirm that SDR, SIR, and SAR are calculated.

(fast_bss_eval) C:\Users\4020737\Documents\git\FastBssEval>python eval.py
C:\Users\4020737\Documents\git\FastBssEval\eval.py:5: WavFileWarning: Chunk (non-data) not understood, skipping it.
  fs, ref = wavfile.read("./test/ref.wav")
C:\Users\4020737\Documents\git\FastBssEval\eval.py:6: WavFileWarning: Chunk (non-data) not understood, skipping it.
  _,  est = wavfile.read("./test/est.wav")
sdr: [13.38897225]
sir: [146.75836169]
sar: [13.38897225]

I have three questions.

  1. If the audio file contains silence, does SIR output inf? I am getting the following warning.

    C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\helpers.py:69: RuntimeWarning: divide by zero encountered in log10
      return 10.0 * np.log10(ratio)
    sdr: [13.86747017]
    sir: [inf]
    sar: [13.86747017]

    When only the audio segment was used, the value was output.

  2. The SIR outputs a value of 146. Is this value reasonable?

  3. What do ref and est mean? Does ref mean voice only file? Does est mean the file after processing?

fakufaku commented 2 years ago

  1. Not necessarily. However, when there is only one channel, then the SIR is infinite, because there is basically no interference (SIR measures the ratio of target to interference).
  2. Yes, this is reasonable. Here the value 146 dB is due to numerical imprecision (it should be inf).
  3. References ref are usually the clean target signals. The est signals are their estimates obtained from a noisy mixture. You can check the doc or section 2 of the paper for a precise definition.
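To illustrate answers 1 and 2, here is a small sketch of a ratio expressed in decibels. The function `ratio_db` is hypothetical and is not the library's implementation; it only shows that zero interference energy gives `inf`, while a tiny rounding-level residual gives a large but finite value.

```python
import numpy as np

def ratio_db(target_energy, interference_energy):
    """Hypothetical helper: energy ratio in dB, returning inf when interference is zero."""
    t = np.float64(target_energy)
    i = np.float64(interference_energy)
    # Suppress the divide-by-zero warning; the resulting inf is the intended answer.
    with np.errstate(divide="ignore"):
        return 10.0 * np.log10(t / i)

exact = ratio_db(1.0, 0.0)       # no interference at all
residual = ratio_db(1.0, 1e-15)  # interference at rounding-error level

print(exact, residual)
```

A residual around 1e-15 relative to the target corresponds to about 150 dB, the same order as the 146 dB reported above.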

Shin-ichi-Takayama commented 2 years ago

Thank you for your response. I will refer to your paper.