adobe-research / MetaAF

Control adaptive filters with neural networks.
https://jmcasebeer.github.io/projects/metaaf
218 stars 38 forks source link

when running system.infer in AEC task, shows ValueError: 'TimeChanCoupledGRU ...' #10

Closed DeboBurro closed 1 year ago

DeboBurro commented 1 year ago

Following Fig. 4 of the paper "Meta-AF: Meta-Learning for Adaptive Filters", I created a function that takes two lists, u and d, as input, and I want it to return the prediction from the pretrained model. But when I run it, it fails with the error ValueError: 'TimeChanCoupledGRU/~/linear/w' with retrieved shape (20, 32) does not match shape=[16, 32] dtype=dtype('complex64'). Do you know why? Also, the prediction corresponds to y in the paper, right?

I copied aec_eval.py to aec_get_output.py and added a couple of functions; otherwise the script is mostly the same.

 .
 .
 . 

def get_output(system, fit_infer, data_dict, out_dir, eval_kwargs, fs=16000):
    """Run ``system.infer`` on the ``u`` and ``d`` signals found in ``data_dict``.

    The ``u`` and ``d`` sequences are extracted with ``get_u`` / ``get_d``,
    ``d`` is wrap-padded up to the length of ``u`` when it is shorter, and
    all-zero ``e`` and ``s`` signals of the same length are synthesized so the
    batch carries all four signal keys.  Every signal is reshaped to
    ``(1, samples, 1)`` before inference.

    Args:
        system: Loaded system object exposing ``infer``.
        fit_infer: Callable forwarded to ``system.infer`` as ``fit_infer``.
        data_dict: Container passed to ``get_u`` and ``get_d``.
        out_dir: Unused here; kept for signature compatibility.
        eval_kwargs: Unused here; kept for signature compatibility.
        fs: Unused here; kept for signature compatibility.

    Returns:
        The first element of the value returned by ``system.infer``.

    NOTE(review): the issue thread reports that 5000-sample inputs triggered a
    Haiku parameter shape mismatch while (16000, 1) inputs worked; presumably
    the signals must be longer than the system's window/hop size -- confirm.
    """
    u = get_u(data_dict)  # e.g. a list of 5000 samples
    d = get_d(data_dict)  # e.g. a list of 5000 samples
    print(f'u_len :{len(u)}, d_len: {len(d)}')

    # No e/s signals are available, so start from a single zero each; wrap
    # padding a lone zero below yields an all-zero signal of length max_len.
    e = [0]
    s = [0]

    # Everything is padded up to the length of u (u itself is only converted
    # to an array, since its pad width is always zero).
    max_len = len(u)
    u = np.pad(u, (0, max(0, max_len - len(u))), "wrap")
    d = np.pad(d, (0, max(0, max_len - len(d))), "wrap")
    e = np.pad(e, (0, max(0, max_len - len(e))), "wrap")
    s = np.pad(s, (0, max(0, max_len - len(s))), "wrap")

    # Append a trailing axis: (samples,) -> (samples, 1).
    u_new = u[:, None]
    d_new = d[:, None]
    e_new = e[:, None]
    s_new = s[:, None]
    print(f'Shapes :u {u_new.shape}, d {d_new.shape}, e {e_new.shape}, s {s_new.shape}')

    # Prepend a leading axis of size one: (samples, 1) -> (1, samples, 1).
    d_input = {"u": u_new[None], "d": d_new[None], "s": s_new[None], "e": e_new[None]}
    pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
    print(pred)

    # The original only printed the prediction, so callers received None.
    return pred

if __name__ == "__main__":
    # Script entry point: parse the CLI options, load the requested
    # checkpoint, report the derived paths/lengths, then run get_output on a
    # CSV dataset.

    parser = argparse.ArgumentParser()

    # which checkpoint to load
    parser.add_argument("--name", type=str, default="")
    parser.add_argument("--date", type=str, default="")
    parser.add_argument("--epoch", type=int, default=0)
    parser.add_argument("--ckpt_dir", type=str, default="./meta_ckpts")

    # evaluation conditions
    parser.add_argument("--universal", action="store_true", default=False)
    parser.add_argument("--system_len", type=int, default=None)

    # only meaningful when --universal is not given
    parser.add_argument("--true_rir_len", type=int, default=None)

    # what to write to disk, and where
    parser.add_argument("--out_dir", type=str, default="./meta_outputs")
    parser.add_argument("--save_outputs", action="store_true")
    parser.add_argument("--save_metrics", action="store_true")

    eval_kwargs = vars(parser.parse_args())
    pprint.pprint(eval_kwargs)

    # the run name and date identify both the checkpoint and the output folder
    run_name = eval_kwargs["name"]
    run_date = eval_kwargs["date"]

    # checkpoint path: <ckpt_dir>/<name>/<date>
    ckpt_loc = os.path.join(eval_kwargs["ckpt_dir"], run_name, run_date)
    epoch = int(eval_kwargs["epoch"])
    print(f'checkpoint location : {ckpt_loc}')

    # restore the system together with its stored kwargs and learned parameters
    system, kwargs, outer_learnable = get_system_ckpt(
        ckpt_loc, epoch, system_len=eval_kwargs["system_len"]
    )
    fit_infer = system.make_fit_infer(outer_learnable=outer_learnable)

    # output path: <out_dir>/<name>/<date>/epoch_<epoch>
    out_dir = os.path.join(
        eval_kwargs["out_dir"], run_name, run_date, f"epoch_{epoch}"
    )
    # the directory is only created when something will be saved into it
    wants_files = eval_kwargs["save_outputs"] or eval_kwargs["save_metrics"]
    if wants_files:
        os.makedirs(out_dir, exist_ok=True)
    print(f'output dir: {out_dir}')

    # report the RIR and filter lengths, using "DEFAULT" when not overridden
    true_rir_len = eval_kwargs["true_rir_len"]
    if true_rir_len is None:
        true_rir_len = "DEFAULT"
    print(f'true RIR length : {true_rir_len}')

    system_len = eval_kwargs["system_len"]
    if system_len is None:
        system_len = "DEFAULT"
    print(f'system len : {system_len}')

    # load the dataset and run inference on it
    data_dict = get_data_dict('/home/burro/Downloads/Dataset.csv')
    predict = get_output(system, fit_infer, data_dict, out_dir, eval_kwargs)

error log:

(metaenv) burro@hostname:~/personal/MetaAF1.0.0/MetaAF-1.0.0/zoo/aec$ python aec_get_output.py --name meta_aec_16_combo_rl_4_1024_512 --date 2022_08_29_01_10_31 --epoch 110 --ckpt_dir ~/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec
{'ckpt_dir': '/home/burro/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec',
 'date': '2022_08_29_01_10_31',
 'epoch': 110,
 'name': 'meta_aec_16_combo_rl_4_1024_512',
 'out_dir': './meta_outputs',
 'save_metrics': False,
 'save_outputs': False,
 'system_len': None,
 'true_rir_len': None,
 'universal': False}
checkpoint location : /home/burro/personal/MetaAF1.0.0/MetaAF-1.0.0/1.0.0-model/v1.0.0_models/aec/meta_aec_16_combo_rl_4_1024_512/2022_08_29_01_10_31
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
output dir: ./meta_outputs/meta_aec_16_combo_rl_4_1024_512/2022_08_29_01_10_31/epoch_110
true RIR length : DEFAULT
system len : DEFAULT
u_len :5000, d_len: 5000
Shapes :(5000, 1), (5000, 1), (5000, 1), (5000, 1)
Traceback (most recent call last):
  File "aec_get_output.py", line 354, in <module>
    predict = get_output(system, fit_infer, data_dict, out_dir, eval_kwargs)
  File "aec_get_output.py", line 224, in get_output
    pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
  File "/home/burro/personal/MetaAF/metaaf/meta.py", line 804, in infer
    filter_s, filter_p, preprocess_s, postprocess_s, batch, key
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 534, in fit_single
    batch_state, batch_hop, jnp.array(subkeys)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/api.py", line 1686, in vmap_f
    ).call_wrapped(*args_flat)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/api.py", line 626, in cache_miss
    top_trace.process_call(primitive, fun_, tracers, params))
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/interpreters/batching.py", line 377, in process_call
    vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/core.py", line 2019, in bind
    outs = top_trace.process_call(self, fun_, tracers, params)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/core.py", line 715, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 250, in _xla_call_impl
    keep_unused=keep_unused)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 237, in _xla_call_impl_lazy
    *arg_specs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 360, in _xla_callable_uncached
    keep_unused, *arg_specs).compile().unsafe_call
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/dispatch.py", line 446, in lower_xla_callable
    fun, pe.debug_info_final(fun, "jit"))
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 445, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/_src/util.py", line 78, in safe_map
    return list(map(f, *args))
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 183, in update
    **optimizer_kwargs,
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 184, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 451, in apply_fn
    out = f(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 124, in _timechancoupled_gru_fwd
    return optimizer(x, h, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 83, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 67, in preprocess_flatten
    return self.in_lin(input_stack_flat)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 125, in __call__
    out = layer(out, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 178, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 603, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 709, in get_parameter
    f"{fq_name!r} with retrieved shape {param.shape!r} does not match "
jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'TimeChanCoupledGRU/~/linear/w' with retrieved shape (20, 32) does not match shape=[16, 32] dtype=dtype('complex64')

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "aec_get_output.py", line 354, in <module>
    predict = get_output(system, fit_infer, data_dict, out_dir, eval_kwargs)
  File "aec_get_output.py", line 224, in get_output
    pred = system.infer({"signals": d_input, "metadata": {}}, fit_infer=fit_infer)[0]
  File "/home/burro/personal/MetaAF/metaaf/meta.py", line 804, in infer
    filter_s, filter_p, preprocess_s, postprocess_s, batch, key
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 534, in fit_single
    batch_state, batch_hop, jnp.array(subkeys)
  File "/home/burro/personal/MetaAF/metaaf/core.py", line 445, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 183, in update
    **optimizer_kwargs,
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 184, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/transform.py", line 451, in apply_fn
    out = f(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 124, in _timechancoupled_gru_fwd
    return optimizer(x, h, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 83, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/MetaAF/metaaf/optimizer_fgru.py", line 67, in preprocess_flatten
    return self.in_lin(input_stack_flat)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 125, in __call__
    out = layer(out, *args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 465, in wrapped
    out = f(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/module.py", line 306, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/basic.py", line 178, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 603, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/home/burro/personal/anaconda3/envs/metaenv/lib/python3.7/site-packages/haiku/_src/base.py", line 709, in get_parameter
    f"{fq_name!r} with retrieved shape {param.shape!r} does not match "
ValueError: 'TimeChanCoupledGRU/~/linear/w' with retrieved shape (20, 32) does not match shape=[16, 32] dtype=dtype('complex64')
DeboBurro commented 1 year ago

Oh, I found the cause: all of the input signals u, d, e, and s need to have shape (16000, 1) in order to run inference.

jmcasebeer commented 1 year ago

Yup, all meta-af modules expect a final channels/microphones dimension.

DeboBurro commented 1 year ago

I think it works as long as the data length is larger than the window size, frame size, or hop size.