adobe-research / MetaAF

Control adaptive filters with neural networks.
https://jmcasebeer.github.io/projects/metaaf
228 stars, 38 forks

Could not get the same AEC results as shown on the demo page with the provided pretrained models #4

Closed: fjiang9 closed this issue 2 years ago

fjiang9 commented 2 years ago

Excellent work! Thanks for sharing the code base and pretrained models.

I would like to try the AEC performance of Meta-AF using your pretrained models. To make sure that I use them correctly, I downloaded the audio files of the first double-talk sample on your demo website and ran AEC with the pretrained model v0.1.0_models/aec/aec_16_dt_c/2022_04_10_15_57_12/epoch230.pkl. However, the AEC results I get are much worse than those on the demo website. Could you please help me out? Here is the test code I used:

import os

import librosa
import numpy as np
import soundfile as sf
from aec_eval import get_system_ckpt  # assumes this runs next to MetaAF's aec_eval.py

# Pretrained double-talk AEC checkpoint
ckpt_dir = "v0.1.0_models/aec/"
name = "aec_16_dt_c"
date = "2022_04_10_15_57_12"
epoch = 230

ckpt_loc = os.path.join(ckpt_dir, name, date)

# Rebuild the system and its learned optimizer from the checkpoint
system, kwargs, outer_learnable = get_system_ckpt(
    ckpt_loc,
    epoch,
    model_type="egru",
    system_len=None,
)
fit_infer = system.make_fit_infer(outer_learnable=outer_learnable)
fs = 16000

out_dir = "metaAF_res"
os.makedirs(out_dir, exist_ok=True)

# Demo-page files: far-end (u), microphone (d), near-end speech (s)
u, _ = librosa.load("u.mp3", sr=fs)
d, _ = librosa.load("d.mp3", sr=fs)
s, _ = librosa.load("s.mp3", sr=fs)
e = d - s  # echo component of the microphone signal

# Shape each signal as (batch, time, channels)
d_input = {
    "u": u[None, :, None],
    "d": d[None, :, None],
    "s": s[None, :, None],
    "e": e[None, :, None],
}
pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
pred = np.array(pred[0, :, 0])

sf.write(os.path.join(out_dir, "_out.wav"), pred, fs)

Looking forward to hearing from you, thanks! Best, Fei

dongsig commented 2 years ago

Hi, I see the sample audio files for both the near-end and far-end speech are compressed in mp3 format. I am wondering whether the mp3 compression causes a time-varying phase change (somewhat like a nonlinear effect on the original signal) that ultimately degrades the acoustic echo cancellation score. @fjiang9 Yours, Dong
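Dong's hypothesis can be illustrated with a toy experiment. The sketch below does not simulate mp3 coding; as a crude stand-in for a time-varying distortion, it toggles the echo path's delay by one sample every 200 samples and compares how well a plain time-domain NLMS canceller (not Meta-AF's learned optimizer) removes the echo in each case. The signals, filter length, step size, and switching period are all illustrative assumptions.

```python
import numpy as np

def nlms_aec(u, d, n_taps=64, mu=0.5, eps=1e-6):
    """Cancel the echo of far-end u in microphone d with a time-domain NLMS filter.

    Returns the error signal e = d - y, i.e. the echo-cancelled output.
    """
    w = np.zeros(n_taps)
    buf = np.zeros(n_taps)  # most recent far-end samples, newest first
    e = np.zeros_like(d)
    for n in range(len(d)):
        buf = np.roll(buf, 1)
        buf[0] = u[n]
        y = w @ buf                                # echo estimate
        e[n] = d[n] - y                            # residual after cancellation
        w += mu * e[n] * buf / (buf @ buf + eps)   # normalized LMS update
    return e

def erle_db(d, e, eps=1e-12):
    """Echo return loss enhancement: microphone energy over residual energy."""
    return 10 * np.log10((np.sum(d**2) + eps) / (np.sum(e**2) + eps))

rng = np.random.default_rng(0)
n = 8000
u = rng.standard_normal(n)                                # far-end stand-in (white noise)
h = rng.standard_normal(32) * np.exp(-np.arange(32) / 8)  # synthetic echo path

# Static echo path
d_static = np.convolve(u, h)[:n]

# Time-varying echo path: delay toggles between 0 and 1 sample every 200 samples
h_shift = np.concatenate([[0.0], h])
d_shift = np.convolve(u, h_shift)[:n]
toggle = (np.arange(n) // 200) % 2 == 1
d_tv = np.where(toggle, d_shift, d_static)

half = n // 2  # score only the second half, after initial convergence
erle_static = erle_db(d_static[half:], nlms_aec(u, d_static)[half:])
erle_tv = erle_db(d_tv[half:], nlms_aec(u, d_tv)[half:])
print(f"static path: {erle_static:.1f} dB, time-varying path: {erle_tv:.1f} dB")
```

An adaptive filter has to re-converge after every change in the path, so any distortion that varies over time leaves residual echo that a static path would not.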

jmcasebeer commented 2 years ago

Hi Fei, thanks for checking out our work. As Dong suggested, I think the issue is the mp3 compression. I ran your code both on signals from the AEC dataloader and on the .mp3 files from my website. The outputs from the dataloader look correct and sound significantly better. I would recommend either using the provided dataloader or re-training with .mp3-style compression.

The code I used is below:

import os

import librosa
import numpy as np
from IPython.display import Audio, display

import aec
import aec_eval

ckpt_dir = "v0.1.0_models/aec/"  # same checkpoint directory as in the original report
name = "aec_16_dt_c"
date = "2022_04_10_15_57_12"
epoch = 230

ckpt_loc = os.path.join(ckpt_dir, name, date)

system, kwargs, outer_learnable = aec_eval.get_system_ckpt(
    ckpt_loc,
    epoch,
    model_type="egru",
    system_len=None,
)
fit_infer = system.make_fit_infer(outer_learnable=outer_learnable)
fs = 16000

# load the mp3 data
u_mp3, _ = librosa.load("./taslp_demos/aec_double_talk/0/u.mp3", sr=fs)
d_mp3, _ = librosa.load("./taslp_demos/aec_double_talk/0/d.mp3", sr=fs)
s_mp3, _ = librosa.load("./taslp_demos/aec_double_talk/0/s.mp3", sr=fs)
e_mp3 = d_mp3 - s_mp3

# add batch and channel axes: (batch, time, channels)
d_mp3_input = {
    "u": u_mp3[None, :, None],
    "d": d_mp3[None, :, None],
    "s": s_mp3[None, :, None],
    "e": e_mp3[None, :, None],
}

# run on mp3 data
pred_mp3 = system.infer({"signals": d_mp3_input, "metadata": {}}, fit_infer=fit_infer)[0]
pred_mp3 = np.array(pred_mp3[0, :, 0])

# load from the dataloader
dset = aec.MSFTAECDataset_RIR(mode="test", double_talk=True, random_roll=True, scene_change=False)
data = dset[0]
u, d, e, s = (
    data["signals"]["u"],
    data["signals"]["d"],
    data["signals"]["e"],
    data["signals"]["s"],
)
# dataloader signals already carry a channel axis, so only add the batch axis
d_input = {"u": u[None], "d": d[None], "s": s[None], "e": e[None]}

# run on the dataloader
pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
pred = np.array(pred[0, :, 0])

display(Audio(pred, rate=fs))
display(Audio(pred_mp3, rate=fs))

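Listening with display(Audio(...)) works, but a number makes the two runs easier to compare. The helper below is a hypothetical ERLE-style score, not part of MetaAF. It assumes the microphone signal is d = echo + s (consistent with e = d - s in the code above), that the output pred is an estimate of s, and that pred is time-aligned with s; if the system adds latency, align the signals first.

```python
import numpy as np

def residual_echo_erle_db(d, s, pred, eps=1e-12):
    """ERLE-style score in dB, assuming d = echo + s and pred estimates s.

    echo = d - s is the echo reaching the microphone; pred - s is what is
    left of it after cancellation (plus any distortion of s).
    """
    # Accept (T,) or (T, 1) arrays and trim to a common length
    d, s, pred = (np.asarray(x, dtype=np.float64).ravel() for x in (d, s, pred))
    n = min(len(d), len(s), len(pred))
    echo = d[:n] - s[:n]
    residual = pred[:n] - s[:n]
    return 10.0 * np.log10((np.sum(echo**2) + eps) / (np.sum(residual**2) + eps))
```

For example, residual_echo_erle_db(d_mp3, s_mp3, pred_mp3) and residual_echo_erle_db(d, s, pred) would score the mp3 run and the dataloader run on the same scale; higher means more echo removed.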
fjiang9 commented 2 years ago

@jmcasebeer Thanks for the quick response! I think the problem is mainly a mismatch in RIR scale between training and testing. The RIRs used for training generally have small values. However, the signals u and d on the demo website seem to have been rescaled, which implies that the effective RIR here is much larger than in the training data. I get reasonable results by rescaling d; see my updated test code:

import os

import librosa
import numpy as np
import soundfile as sf
from aec_eval import get_system_ckpt

ckpt_dir = "v0.1.0_models/aec/"
name = "aec_16_dt_c"
date = "2022_04_10_15_57_12"
epoch = 230

ckpt_loc = os.path.join(ckpt_dir, name, date)

system, kwargs, outer_learnable = get_system_ckpt(
    ckpt_loc,
    epoch,
    model_type="egru",
    system_len=None,
)
fit_infer = system.make_fit_infer(outer_learnable=outer_learnable)
fs = 16000

out_dir = "metaAF_res"
os.makedirs(out_dir, exist_ok=True)

u, _ = librosa.load("u.mp3", sr=fs)
d, _ = librosa.load("d.mp3", sr=fs)
s, _ = librosa.load("s.mp3", sr=fs)
e = d - s

# Attenuate the microphone signal so its echo level is closer to the training data
scale = 10
d = d / scale

d_input = {
    "u": u[None, :, None],
    "d": d[None, :, None],
    "s": s[None, :, None],
    "e": e[None, :, None],
}
pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
# Undo the attenuation so the output is back at the original level
pred = np.array(pred[0, :, 0]) * scale

sf.write(os.path.join(out_dir, "_out.wav"), pred, fs)
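The rescaling trick above can be factored into a small wrapper, plus a level check for choosing the scale. This is a minimal sketch under stated assumptions: infer_fn is a hypothetical wrapper around system.infer that takes 1-D u and d and returns the 1-D output, and echo_gain_db is a hypothetical diagnostic (not part of MetaAF) for comparing the echo level of the demo files with that of dataloader examples.

```python
from typing import Callable

import numpy as np

def infer_at_scale(infer_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
                   u: np.ndarray, d: np.ndarray, scale: float) -> np.ndarray:
    """Run infer_fn on d / scale and map its output back to the original level."""
    if scale <= 0:
        raise ValueError("scale must be positive")
    out = infer_fn(u, d / scale)
    return np.asarray(out) * scale

def rms(x: np.ndarray) -> float:
    return float(np.sqrt(np.mean(np.square(x))))

def echo_gain_db(u: np.ndarray, echo: np.ndarray) -> float:
    """Level of the echo relative to the far-end signal, in dB (u must not be silent)."""
    return 20.0 * np.log10(rms(echo) / rms(u))
```

Comparing echo_gain_db(u, d - s) for the demo files against a few dataloader examples would show how far apart the two echo levels are; a gap of about 20 dB would correspond to scale = 10.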

I think adding the RIR scale augmentation during training could be helpful.

jmcasebeer commented 2 years ago

Good catch. I had rescaled the demo website .mp3 files to be in [-1,1] for playback. I added a disclaimer about this to the website.
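Rescaling for playback matters because it changes the level relationship between u and d. Assuming each file was peak-normalized independently (the comment only says the files were rescaled to [-1, 1]), the effective echo-path gain seen by the canceller is multiplied by the ratio of the two normalization gains. The toy signals and the -26 dB echo path below are illustrative assumptions.

```python
import numpy as np

def peak_normalize(x, eps=1e-12):
    """Scale x into [-1, 1]; return the scaled signal and the gain applied."""
    gain = 1.0 / (np.max(np.abs(x)) + eps)
    return x * gain, gain

rng = np.random.default_rng(1)
u = 0.1 * rng.standard_normal(16000)   # far-end speech stand-in
echo = 0.05 * u                        # toy echo path: pure gain of about -26 dB
s = 0.02 * rng.standard_normal(16000)  # near-end speech stand-in
d = echo + s

u_n, g_u = peak_normalize(u)
d_n, g_d = peak_normalize(d)

# The echo inside d_n is g_d * echo, while the far-end reference is g_u * u,
# so the apparent echo-path gain changes by g_d / g_u.
print(f"echo-path gain changed by {20 * np.log10(g_d / g_u):.1f} dB")

# Keeping the gains makes the change reversible
d_restored = d_n / g_d
```

Because the microphone signal is usually quieter than the far-end signal, normalizing both to full scale makes the echo path look louder than anything seen in training, which matches the mismatch described above.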

I also agree that training with some data augmentation would be useful.
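A minimal sketch of the RIR scale augmentation discussed above, assuming training examples are built by convolving the far-end signal with an RIR and adding near-end speech. The gain range is an illustrative assumption, not a value from MetaAF, and where this would hook into MetaAF's dataset classes is not shown here.

```python
import numpy as np

def augment_echo_gain(u, rir, s, rng, gain_db_range=(-20.0, 10.0)):
    """Build a microphone signal from a randomly scaled RIR.

    Returns d = (gain * rir) convolved with u, plus s; the echo alone; and the gain in dB.
    gain_db_range is a hypothetical choice for illustration.
    """
    if len(s) != len(u):
        raise ValueError("u and s must have the same length")
    gain_db = rng.uniform(*gain_db_range)
    gain = 10.0 ** (gain_db / 20.0)
    echo = np.convolve(u, gain * rir)[: len(u)]  # scaled echo, trimmed to len(u)
    d = echo + s
    return d, echo, gain_db
```

Drawing the gain in dB rather than linearly spreads the examples evenly across quiet and loud echo levels.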

jmcasebeer commented 2 years ago

Hi, thanks for your interest!

Your question seems unrelated to the issue raised above. Could you please make a new issue and include things like:

Currently, it looks like the AEC challenge dataset was not fully downloaded ... but it is hard to tell without more information.

Thanks!

Alirezanezamdoost commented 2 years ago

I have downloaded the "AEC-Challenge-main" dataset and the "RIRS_NOISES" dataset and set their paths in the config. I would like to try the AEC performance of Meta-AF using your pretrained models. When I run the command from the tutorial:

!python /content/MetaAF/zoo/aec/aec.py --n_frames 1 --window_size 2048 --hop_size 1024 --n_in_chan 1 --n_out_chan 1 --is_real --n_devices 1 --batch_size 64 --total_epochs 1000 --val_period 10 --reduce_lr_patience 1 --early_stop_patience 4 --name meta_aec_demo --unroll 16 --extra_signals ude --random_roll --outer_loss log_self_mse --double_talk --dataset nonlinear

I am shown three choices:

wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results

If I select 1, register for a W&B account, and enter my W&B API key, I receive:

wandb.errors.CommError: Permission denied, ask the project owner to grant you access

If I select 3, I receive:

RuntimeError: Error opening '/content/AEC-Challenge main/datasets/synthetic/farend_speech/farend_speech_fileid_0.wav': File contains data in an unknown format.

Could you please help me solve this error?
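The "unknown format" error means the file could not be decoded as audio, which fits the earlier suggestion that the dataset was not fully downloaded. One plausible cause (an assumption, not confirmed in this thread) is cloning the dataset repository without Git LFS, which leaves small text pointer files in place of the .wav files. The hypothetical helper below scans a directory and reports .wav files whose header is not RIFF/WAVE; other valid containers such as RF64 would also be reported as "unknown".

```python
from pathlib import Path

LFS_PREFIX = b"version https://git-lfs.github.com/spec/v1"

def classify_wav(path):
    """Return 'wav', 'lfs-pointer', 'empty', or 'unknown' based on the file header."""
    with open(path, "rb") as f:
        head = f.read(64)
    if not head:
        return "empty"
    if head.startswith(LFS_PREFIX):
        return "lfs-pointer"
    if head[:4] == b"RIFF" and head[8:12] == b"WAVE":
        return "wav"
    return "unknown"

def find_bad_wavs(root):
    """List (path, kind) for every .wav under root that is not a RIFF/WAVE file."""
    bad = []
    for p in sorted(Path(root).rglob("*.wav")):
        kind = classify_wav(p)
        if kind != "wav":
            bad.append((p, kind))
    return bad
```

For example, find_bad_wavs("/content/AEC-Challenge main/datasets/synthetic") would list the problem files; if they are LFS pointers, fetching them with Git LFS should replace them with the real audio.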

jmcasebeer commented 2 years ago

Thanks @Alirezanezamdoost! Let's continue the discussion in a new issue here.

aleksandra-bebe commented 11 months ago

@fjiang9 @jmcasebeer Hello! If you are using pretrained models, why is it necessary to have a dataset when running the script? Also, when I run this code I get the error below. Can you help me with this?

Exception has occurred: ValueError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
'ElementWiseGRU/~/linear/w' with retrieved shape (4, 32) does not match shape=[5, 32] dtype=dtype('complex64')

  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\base.py", line 685, in get_parameter
    raise ValueError(
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\basic.py", line 179, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\basic.py", line 126, in __call__
    out = layer(out, *args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 71, in preprocess_flatten
    return self.in_lin(input_stack_flat)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 80, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 122, in _fwd
    return optimizer(x, h, extra_inputs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\transform.py", line 183, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 212, in update
    update, state = optimizer.apply(
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\example_libraries\optimizers.py", line 199, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\core.py", line 462, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\core.py", line 549, in fit_single
    cur_out, loss, batch_state = batch_step(
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\meta.py", line 825, in infer
    out, aux = fit_infer(
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\zoo\aec\start.py", line 38, in <module>
    pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\runpy.py", line 198, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
ValueError: 'ElementWiseGRU/~/linear/w' with retrieved shape (4, 32) does not match shape=[5, 32] dtype=dtype('complex64')
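The traceback ends in optimizer_gru.py at self.in_lin(input_stack_flat), and Haiku's linear weight has shape [input_size, output_size]. So the checkpoint's input layer was trained on 4 stacked input features per bin, while the current run feeds 5; this likely means the run's options (for example, which extra signals reach the optimizer) differ from those used to train the checkpoint. A hypothetical helper for locating such mismatches is sketched below. It relies on Haiku storing parameters as nested mappings of the form {module_name: {param_name: array}}; how to obtain the two trees from MetaAF (a fresh init versus the loaded .pkl) is an assumption left out of the sketch.

```python
from collections.abc import Mapping

import numpy as np

def shape_mismatches(expected, loaded, prefix=""):
    """Compare two nested parameter trees and describe every shape difference."""
    problems = []
    for key in sorted(set(expected) | set(loaded), key=str):
        path = f"{prefix}/{key}" if prefix else str(key)
        if key not in loaded:
            problems.append(f"{path}: missing from checkpoint")
        elif key not in expected:
            problems.append(f"{path}: unexpected in checkpoint")
        elif isinstance(expected[key], Mapping) and isinstance(loaded[key], Mapping):
            problems.extend(shape_mismatches(expected[key], loaded[key], path))
        elif np.shape(expected[key]) != np.shape(loaded[key]):
            problems.append(
                f"{path}: model expects {np.shape(expected[key])}, "
                f"checkpoint has {np.shape(loaded[key])}"
            )
    return problems
```

Run against the shapes in the error above, it would report "ElementWiseGRU/~/linear/w: model expects (5, 32), checkpoint has (4, 32)", which points to the input-feature count rather than to the data.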