dansuh17 / segan-pytorch

SEGAN pytorch implementation https://arxiv.org/abs/1703.09452
GNU General Public License v3.0

de_emphasis function is broken #17

Closed alessandrobessi closed 5 years ago

alessandrobessi commented 5 years ago

It looks like your de_emphasis function is broken: batch != de_emphasis(pre_emphasis(batch)). Did you test it?

alessandrobessi commented 5 years ago

Here is a working version:

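import numpy as np

def pre_emphasis(batch: np.ndarray, emph_coeff: float = 0.95) -> np.ndarray:
    # y[n] = x[n] - emph_coeff * x[n-1]; the first sample passes through unchanged.
    # batch[:, 0, 0] is 1-D, so the new axis must be 1 (axis=2 is out of range).
    result = np.concatenate((np.expand_dims(batch[:, 0, 0], axis=1), batch[:, 0, 1:] - emph_coeff
                             * batch[:, 0, :-1]), axis=1)
    return np.expand_dims(result, axis=1)

def de_emphasis(batch: np.ndarray, emph_coeff: float = 0.95) -> np.ndarray:
    # Recursive inverse: x[n] = y[n] + emph_coeff * x[n-1]
    result = np.zeros(batch.shape)
    result[:, 0, 0] = batch[:, 0, 0]
    for i in range(batch.shape[2] - 1):
        result[:, 0, i + 1] = batch[:, 0, i + 1] + emph_coeff * result[:, 0, i]

    return result

# test
# np.random.random_integers is deprecated; randint excludes `high`, hence 11
batch = np.random.randint(low=1, high=11, size=(2, 1, 7))
print(batch)
print()
emph_batch = pre_emphasis(batch)
print(emph_batch)
print()
deemph_batch = de_emphasis(emph_batch)
print(deemph_batch)
# the round trip should reproduce the input up to floating-point error
assert np.allclose(batch, deemph_batch)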
arijit17 commented 5 years ago

@alessandrobessi this is interesting. Is your code running with the rest of the training?

alessandrobessi commented 5 years ago

@arijit17 it should be consistent with the rest of the code. In my implementation I used PyTorch and changed many things, but as long as a batch goes in and a transformed batch of the same shape comes out, everything should be fine.

arijit17 commented 5 years ago

Currently, if I run the training after simply replacing the existing functions with yours, I get an error:

Traceback (most recent call last):
  File "model.py", line 378, in <module>
    ref_batch_var, ref_clean_var, ref_noisy_var = split_pair_to_vars(ref_batch_pairs)
  File "model.py", line 340, in split_pair_to_vars
    noisy_batch = np.stack([pair[1].reshape(1, -1) for pair in sample_batch_pair])
  File "model.py", line 340, in <listcomp>
    noisy_batch = np.stack([pair[1].reshape(1, -1) for pair in sample_batch_pair])
alessandrobessi commented 5 years ago

@arijit17 you have to apply those functions to clean_batch and noisy_batch separately, when they are numpy arrays.
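
A minimal sketch of what "separately" could look like, assuming each element of `sample_batch_pair` is an array of shape `(2, length)` with the clean signal at index 0 and the noisy one at index 1. That layout is only inferred from the traceback above, and the random data is a hypothetical stand-in for the real loader. The compact `pre_emphasis` here is a vectorized equivalent written for this example, not the repository's code.

```python
import numpy as np

def pre_emphasis(batch, emph_coeff=0.95):
    # y[n] = x[n] - emph_coeff * x[n-1] along the last axis
    out = batch.astype(float)  # astype returns a copy, so the input is untouched
    out[..., 1:] -= emph_coeff * batch[..., :-1]
    return out

# Hypothetical stand-in for the loader output: 4 pairs, each of shape (2, 8)
rng = np.random.default_rng(0)
sample_batch_pair = [rng.standard_normal((2, 8)) for _ in range(4)]

# Build the two NumPy batches first, each of shape (batch, 1, length)
clean_batch = np.stack([pair[0].reshape(1, -1) for pair in sample_batch_pair])
noisy_batch = np.stack([pair[1].reshape(1, -1) for pair in sample_batch_pair])

# Filter each batch on its own, while both are still NumPy arrays
clean_batch = pre_emphasis(clean_batch)
noisy_batch = pre_emphasis(noisy_batch)
print(clean_batch.shape, noisy_batch.shape)
```

Any conversion to tensors would then happen after this step.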

arijit17 commented 5 years ago

All right @alessandrobessi

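Basically, the current de-emphasis implementation is broken because it applies an FIR filter where an IIR filter is needed. Pre-emphasis computes y[n] = x[n] - 0.95 x[n-1], so its inverse has to be the recursive filter x[n] = y[n] + 0.95 x[n-1].

Here is the verification in MATLAB: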

>> x = [5 10  9  1  7  1  7];
>> pre_emp = filter([1 -0.95],1,x)

pre_emp =

    5.0000    5.2500   -0.5000   -7.5500    6.0500   -5.6500    6.0500

>> wrong_demph = filter([1 0.95],1,pre_emp) %current implementation

wrong_demph =

    5.0000   10.0000    4.4875   -8.0250   -1.1225    0.0975    0.6825

>> correct_demph = filter(1, [1 -0.95],pre_emp) %correct implementation

correct_demph =

     5    10     9     1     7     1     7
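
For readers without MATLAB, the same check can be sketched in plain NumPy. The helpers `fir` and `iir_one_pole` are hypothetical names written for this example; they cover only the two-tap cases used in the session above, not MATLAB's general `filter`.

```python
import numpy as np

ALPHA = 0.95  # emphasis coefficient used throughout the thread

def fir(x, b0, b1):
    """y[n] = b0*x[n] + b1*x[n-1], like MATLAB filter([b0 b1], 1, x)."""
    x = np.asarray(x, dtype=float)
    y = b0 * x
    y[1:] += b1 * x[:-1]
    return y

def iir_one_pole(x, a1):
    """y[n] = x[n] - a1*y[n-1], like MATLAB filter(1, [1 a1], x)."""
    x = np.asarray(x, dtype=float)
    y = np.zeros_like(x)
    prev = 0.0  # zero initial state, matching filter's default
    for n, sample in enumerate(x):
        prev = sample - a1 * prev
        y[n] = prev
    return y

x = np.array([5, 10, 9, 1, 7, 1, 7], dtype=float)
pre_emp = fir(x, 1.0, -ALPHA)                   # pre-emphasis
wrong_demph = fir(pre_emp, 1.0, ALPHA)          # FIR "inverse": does not recover x
correct_demph = iir_one_pole(pre_emp, -ALPHA)   # recursive inverse: recovers x

print(np.round(pre_emp, 4))
print(np.round(wrong_demph, 4))
print(np.round(correct_demph, 4))
```

The FIR version only flips the sign of the tap, which is not the inverse of the pre-emphasis filter. The recursive version feeds its own previous output back in, and that feedback is what undoes the subtraction.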
dansuh17 commented 5 years ago

Thanks for the correction, @alessandrobessi. And right, the version implemented here is indeed an FIR filter, @arijit17. I'll provide a fix soon.

dansuh17 commented 5 years ago

Fixed the broken de-emphasis in fc54e81cec52d0141fdd96b4677a4c7252c47bca. Managed to use scipy.signal.lfilter, following the hint given by @arijit17. Closing.
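
As a rough illustration of how `scipy.signal.lfilter` can express both directions, here is a sketch that filters along the last axis of a `(batch, 1, length)` array. It is not taken from the commit above: the function names, the batch shape, and the random test data are assumptions made for this example.

```python
import numpy as np
from scipy.signal import lfilter

def pre_emphasis_lfilter(batch, emph_coeff=0.95):
    # FIR: y[n] = x[n] - emph_coeff * x[n-1]
    return lfilter([1.0, -emph_coeff], [1.0], batch, axis=-1)

def de_emphasis_lfilter(batch, emph_coeff=0.95):
    # IIR inverse: x[n] = y[n] + emph_coeff * x[n-1]
    return lfilter([1.0], [1.0, -emph_coeff], batch, axis=-1)

# Hypothetical batch: 2 signals, 1 channel, 16384 samples each
rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 1, 16384))

restored = de_emphasis_lfilter(pre_emphasis_lfilter(batch))
print(np.allclose(batch, restored))
```

Compared with the explicit Python loop, `lfilter` runs the recursion in compiled code, which matters for long signals.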