adobe-research / MetaAF

Control adaptive filters with neural networks.
https://jmcasebeer.github.io/projects/metaaf

with retrieved shape (4, 32) does not match shape=[5, 32] dtype=dtype('complex64') #20

aleksandra-bebe closed this issue 7 months ago.

aleksandra-bebe commented 8 months ago

Hello, when I run the code from the second closed issue, which uses the pre-trained models and AEC, I get the following error. Can you help me with this?

Exception has occurred: ValueError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
'ElementWiseGRU//linear/w' with retrieved shape (4, 32) does not match shape=[5, 32] dtype=dtype('complex64')
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\base.py", line 685, in get_parameter
    raise ValueError(
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\basic.py", line 179, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\basic.py", line 126, in __call__
    out = layer(out, *args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 71, in preprocess_flatten
    return self.in_lin(input_stack_flat)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 80, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 122, in _fwd
    return optimizer(x, h, extra_inputs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\haiku\_src\transform.py", line 183, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\optimizer_gru.py", line 212, in update
    update, state = optimizer.apply(
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\example_libraries\optimizers.py", line 199, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\core.py", line 462, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\core.py", line 549, in fit_single
    cur_out, loss, batch_state = batch_step(
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\metaaf\meta.py", line 825, in infer
    out, aux = fit_infer(
  File "C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\zoo\aec\start.py", line 38, in <module>
    pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "C:\Users\pc\AppData\Local\Programs\Python\Python311\Lib\runpy.py", line 198, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
ValueError: 'ElementWiseGRU//linear/w' with retrieved shape (4, 32) does not match shape=[5, 32] dtype=dtype('complex64')
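The error says that a parameter read from the checkpoint (a (4, 32) complex weight used by `in_lin` inside the learned GRU optimizer, per the traceback) does not have the (5, 32) shape that the currently running code asks Haiku for. A minimal, hypothetical diagnostic for this kind of mismatch is sketched below. It assumes you can get both the loaded parameters and a freshly built set of parameters as Haiku-style nested dicts (`{module: {param: array}}`); `find_shape_mismatches` and the module names in the example are illustrative only, not part of MetaAF or Haiku.

```python
import numpy as np


def find_shape_mismatches(loaded, expected):
    """Compare two Haiku-style parameter trees of the form {module: {param: array}}.

    Returns (path, loaded_shape, expected_shape) for every parameter whose shape
    differs; a parameter present in only one tree is reported with None on the
    other side.
    """
    mismatches = []
    for module in sorted(set(loaded) | set(expected)):
        loaded_mod = loaded.get(module, {})
        expected_mod = expected.get(module, {})
        for name in sorted(set(loaded_mod) | set(expected_mod)):
            l_shape = np.shape(loaded_mod[name]) if name in loaded_mod else None
            e_shape = np.shape(expected_mod[name]) if name in expected_mod else None
            if l_shape != e_shape:
                mismatches.append((f"{module}/{name}", l_shape, e_shape))
    return mismatches


# Hypothetical trees that reproduce the shapes reported in the error:
# the checkpoint stores a (4, 32) weight, the current code expects (5, 32).
ckpt_params = {"ElementWiseGRU/linear": {"w": np.zeros((4, 32), np.complex64)}}
model_params = {"ElementWiseGRU/linear": {"w": np.zeros((5, 32), np.complex64)}}
print(find_shape_mismatches(ckpt_params, model_params))
# [('ElementWiseGRU/linear/w', (4, 32), (5, 32))]
```

A non-empty result usually means the weights and the code come from different releases or configurations, so the stored input size of a layer no longer matches what the code feeds it.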

jmcasebeer commented 8 months ago

Hello, and thanks for the interest. Could you please share the exact code you ran, and describe what setup you did before that?

aleksandra-bebe commented 8 months ago

@jmcasebeer I downloaded the pre-trained models from the link and provided the path to the AEC model. I also downloaded the AEC challenge dataset and set its path in the config.py script. The error is raised on this line:

pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]

This is the code that I ran:

import os
from aec_eval import get_system_ckpt
import numpy as np
import librosa
import soundfile as sf
import aec

ckpt_dir = r"C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\v0.1.0_models\aec"
name = "aec_16_dt_c"
date = "2022_04_10_15_57_12"
epoch = 230

ckpt_loc = os.path.join(ckpt_dir, name, date)

system, kwargs, outer_learnable = get_system_ckpt(
    ckpt_loc,
    epoch,
    model_type="egru",
    system_len=None,
)
fit_infer = system.make_fit_infer(outer_learnable=outer_learnable)
fs = 16000

out_dir = "metaAF_res"
os.makedirs(out_dir, exist_ok=True)

u, _ = librosa.load(r"C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\zoo\aec\u.wav", sr=fs)
d, _ = librosa.load(r"C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\zoo\aec\d.wav", sr=fs)
s, _ = librosa.load(r"C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\zoo\aec\s.wav", sr=fs)
e = d - s
d_mp3_input = {"u": u[None, :, None], "d": d[None, :, None], "s": s[None, :, None], "e": e[None, :, None]}

pred_mp3 = system.infer({"signals": d_mp3_input, "metadata": {}}, fit_infer=fit_infer)[0]
pred_mp3 = np.array(pred_mp3[0, :, 0])

dset = aec.MSFTAECDataset_RIR(mode='test', double_talk=True, random_roll=True, scene_change=False)
data = dset[0]
u, d, e, s = (
    data["signals"]["u"],
    data["signals"]["d"],
    data["signals"]["e"],
    data["signals"]["s"],
)
d_input = {"u": u[None], "d": d[None], "s": s[None], "e": e[None]}

pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
pred = np.array(pred[0, :, 0])

sf.write(os.path.join(out_dir, f"_out.wav"), pred, fs)
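In the script above, each WAV signal gets a leading batch axis and a trailing channel axis (`u[None, :, None]`) before it is passed to `system.infer`, and the output is read back with `pred[0, :, 0]`. A hypothetical helper for that preprocessing step is sketched below; `make_signal_dict` is not part of MetaAF. It also trims the signals to a common length, an assumption added here because `e = d - s` raises a NumPy broadcasting error if the files differ in length.

```python
import numpy as np


def make_signal_dict(u, d, s):
    """Build the {"u", "d", "s", "e"} input dict from 1-D signals (hypothetical helper).

    Every output array has shape (1, T, 1), matching u[None, :, None] in the
    script, and the echo is computed as e = d - s, as in the script.
    """
    # Trim to a common length so that d - s is well defined
    # (assumption: the original script relies on the files already matching).
    n = min(len(u), len(d), len(s))
    u, d, s = (np.asarray(x[:n], dtype=np.float32) for x in (u, d, s))
    e = d - s
    # Add batch and channel axes: (T,) -> (1, T, 1).
    return {k: v[None, :, None] for k, v in {"u": u, "d": d, "s": s, "e": e}.items()}


# Hypothetical example: random signals stand in for u.wav, d.wav and s.wav.
fs = 16000
rng = np.random.default_rng(0)
u = rng.standard_normal(fs)
s = rng.standard_normal(fs)
d = s + 0.5 * u
signals = make_signal_dict(u, d, s)
print({k: v.shape for k, v in signals.items()})
# {'u': (1, 16000, 1), 'd': (1, 16000, 1), 's': (1, 16000, 1), 'e': (1, 16000, 1)}
```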

aleksandra-bebe commented 7 months ago

@jmcasebeer
Hello! I hope you had a wonderful holiday. Do you have a solution for my issue?

jmcasebeer commented 7 months ago

Could you please try using the most recent model weights? It looks like you're using the 0.1.0 weights with the 1.0.1 code. You can get the 1.0.1 weights here. Then, update the checkpoint loading to use a more recent checkpoint.

For example:

import os
from aec_eval import get_system_ckpt

# Point at the v1.0.1 checkpoint instead of the v0.1.0 one.
ckpt_dir = "v1.0.1_models/aec/"
name = "meta_aec_16_combo_rl_4_1024_512_r2"
date = "2022_10_19_23_43_22"
epoch = 110

ckpt_loc = os.path.join(ckpt_dir, name, date)

system, kwargs, outer_learnable = get_system_ckpt(
    ckpt_loc,
    epoch,
)
fit_infer = system.make_fit_infer(outer_learnable=outer_learnable)
fs = 16000
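The maintainer's reply traces the error to loading 0.1.0 weights with the 1.0.1 code. One cheap guard against this, sketched below, is to compare the version embedded in the checkpoint path with the version of the code you installed before calling `get_system_ckpt`. It assumes the weights live in folders named like `v1.0.1_models`, as in this thread; `weights_version`, `check_weights_match_code` and `CODE_VERSION` are hypothetical names, not MetaAF APIs.

```python
import re


def weights_version(ckpt_dir):
    """Return a version such as "1.0.1" from a path containing "v1.0.1_models", else None."""
    match = re.search(r"v(\d+\.\d+\.\d+)_models", ckpt_dir)
    return match.group(1) if match else None


def check_weights_match_code(ckpt_dir, code_version):
    """Raise ValueError if the weights folder version differs from the code version."""
    found = weights_version(ckpt_dir)
    if found is None:
        raise ValueError(f"cannot infer a weights version from {ckpt_dir!r}")
    if found != code_version:
        raise ValueError(
            f"weights v{found} do not match code v{code_version}; "
            f"download the v{code_version} weights"
        )
    return found


CODE_VERSION = "1.0.1"  # hypothetical: set this to the MetaAF release you installed
print(check_weights_match_code("v1.0.1_models/aec/", CODE_VERSION))  # 1.0.1
try:
    check_weights_match_code(
        r"C:\Users\pc\Desktop\AI-Beamformers\meta-af\MetaAF\v0.1.0_models\aec", CODE_VERSION
    )
except ValueError as err:
    print(err)  # weights v0.1.0 do not match code v1.0.1; download the v1.0.1 weights
```

Failing early with a readable message is the design choice here: the shape error from Haiku only appears deep inside `system.infer`, after the model has already been built.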

aleksandra-bebe commented 7 months ago

@jmcasebeer Thank you very much, now it works!