Closed EmanuelSommer closed 3 months ago
Good catch! The error means that the code is being run outside of a NumPyro context, in particular outside the seed handler. We should raise a better error here.
Thanks for the quick reply! Could you please briefly elaborate on how to fix the issue in general, i.e. how to specify the context in a generic way that does not interfere with any downstream sampling tasks? Thanks in advance :)
To seed a program, you can put it under a handler:

    with handlers.seed(rng_seed=0):
        # random_flax_module...
That's what I thought, just wanted to make sure I don't overlook any subtlety :) Thanks again!
First and foremost, thanks for the great work on numpyro!

The utility function random_flax_module() from the numpyro.contrib.module seems to be broken. As a minimal reproducible example for the error I take the example given in the docstring of the function (https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/module.py#L285) itself. The example with imports is given below:

This leads directly to

ValueError: First argument passed to an init function should be a jax.PRNGKey or a dictionary mapping strings to jax.PRNGKey.
Hope I just missed something obvious as the function would be very handy!