adobe-research / MetaAF

Control adaptive filters with neural networks.
https://jmcasebeer.github.io/projects/metaaf
218 stars, 38 forks

You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers when running hoaec_eval.py #11

Closed: apra-da closed this issue 8 months ago

apra-da commented 1 year ago

I want to run hoaec_eval.py through the r_es_ckptsrun.sh bash script to try out the algorithms, but it fails with the error "You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers." My code is practically identical to the repository's; I only changed the constant values in config.py to match my folders. Here's the complete traceback. Hope you can help!

Storing AEC outputs...
/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/torch/utils/data/dataloader.py:561: UserWarning: This DataLoader will create 10 worker processes in total. Our suggested max number of worker in current system is 8 (`cpuset` is not taken into account), which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
  0%|          | 0/32 [00:00<?, ?it/s]
/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/complex_gru.py:114: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  return jax.tree_map(broadcast, nest)
  0%|          | 0/32 [00:11<?, ?it/s]
Traceback (most recent call last):
  File "/Users/agus/work/eye-predict/audioEngineering/External-repos/MetaAFPackage/hoaec_eval.py", line 225, in <module>
    preds = system.infer(data)[0]
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/meta.py", line 825, in infer
    out, aux = fit_infer(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 549, in fit_single
    cur_out, loss, batch_state = batch_step(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/api.py", line 1564, in vmap_f
    out_flat = batching.batch(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/api.py", line 526, in cache_miss
    out_flat = xla.xla_call(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1919, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1935, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/interpreters/batching.py", line 233, in process_call
    vals_out = call_primitive.bind(f_, *vals, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1919, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 1935, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/core.py", line 687, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 199, in _xla_call_impl
    compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/linear_util.py", line 295, in memoized_fun
    ans = call(fun, *args)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 248, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 293, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2167, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2117, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 462, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/_src/util.py", line 47, in safe_map
    return list(map(f, *args))
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 471, in update
    update, state = optimizer.apply(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 234, in _fwd
    return optimizer(x, h, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 183, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 149, in preprocess_flatten
    input_stack_flat = self.in_coupling_conv(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/basic.py", line 123, in __call__
    out = layer(out, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/conv.py", line 200, in __call__
    w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 515, in get_parameter
    param = init(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/complex_utils.py", line 12, in complex_variance_scaling
    real = hk.initializers.VarianceScaling()(shape, dtype=jnp.float32)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 215, in __call__
    return TruncatedNormal(stddev=stddev)(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 114, in __call__
    unscaled = jax.random.truncated_normal(hk.next_rng_key(), -2., 2., shape,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 965, in next_rng_key
    return next_rng_key_internal()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 1003, in next_rng_key_internal
    rng_seq = rng_seq_or_fail()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 923, in rng_seq_or_fail
    raise ValueError("You must pass a non-None PRNGKey to init and/or apply "
jax._src.traceback_util.UnfilteredStackTrace: ValueError: You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/agus/work/eye-predict/audioEngineering/External-repos/MetaAFPackage/hoaec_eval.py", line 225, in <module>
    preds = system.infer(data)[0]
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/meta.py", line 825, in infer
    out, aux = fit_infer(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 549, in fit_single
    cur_out, loss, batch_state = batch_step(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/core.py", line 462, in online_step
    opt_s = opt_update(0, filter_features, opt_s)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/jax/example_libraries/optimizers.py", line 196, in tree_update
    new_states = map(partial(update, i), grad_flat, states)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 471, in update
    update, state = optimizer.apply(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 234, in _fwd
    return optimizer(x, h, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 183, in __call__
    rnn_in = self.preprocess_flatten(x, extra_inputs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/optimizer_hogru.py", line 149, in preprocess_flatten
    input_stack_flat = self.in_coupling_conv(
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/basic.py", line 123, in __call__
    out = layer(out, *args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 434, in wrapped
    out = f(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/module.py", line 273, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/conv.py", line 200, in __call__
    w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 515, in get_parameter
    param = init(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/metaaf/complex_utils.py", line 12, in complex_variance_scaling
    real = hk.initializers.VarianceScaling()(shape, dtype=jnp.float32)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 215, in __call__
    return TruncatedNormal(stddev=stddev)(shape, dtype)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/initializers.py", line 114, in __call__
    unscaled = jax.random.truncated_normal(hk.next_rng_key(), -2., 2., shape,
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 965, in next_rng_key
    return next_rng_key_internal()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 1003, in next_rng_key_internal
    rng_seq = rng_seq_or_fail()
  File "/Users/agus/miniforge3/envs/metaenv/lib/python3.10/site-packages/haiku/_src/base.py", line 923, in rng_seq_or_fail
    raise ValueError("You must pass a non-None PRNGKey to init and/or apply "
ValueError: You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers.

jmcasebeer commented 1 year ago

Hey, thanks for the question. I looked into this and it seems I missed part of hoaec when refactoring. I've created a new branch with a fix. Can you let me know if that works for you?