pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

random_flax_module broken #1741

Closed (EmanuelSommer closed this issue 3 months ago)

EmanuelSommer commented 5 months ago

First and foremost, thanks for the great work on NumPyro!

The utility function random_flax_module() from numpyro.contrib.module seems to be broken. As a minimal reproducible example, I use the example from the function's own docstring (https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/module.py#L285). The example, with imports, is:

import flax
import numpyro.distributions as dist
from numpyro.contrib.module import random_flax_module
random_flax_module(
    "net",
    flax.linen.Dense(features=1),
    prior={"bias": dist.Cauchy(), "kernel": dist.Normal()},
    input_shape=(4,)
)

This leads directly to ValueError: First argument passed to an init function should be a jax.PRNGKey or a dictionary mapping strings to jax.PRNGKey.

I hope I just missed something obvious, as the function would be very handy!

fehiepsi commented 5 months ago

Good catch! The error means that the call runs outside of a NumPyro context, in particular outside the seed handler, so no PRNG key is available for initializing the Flax module. We should raise a clearer error here.

EmanuelSommer commented 5 months ago

Thanks for the quick reply! Could you briefly explain how to fix this in general, i.e., how to provide the context in a generic way that does not interfere with any downstream sampling tasks? Thanks in advance :)

fehiepsi commented 4 months ago

To seed a program, you can run it under a seed handler:

from numpyro import handlers
with handlers.seed(rng_seed=0):
    # call random_flax_module(...) here, e.g. the docstring example above
    ...
EmanuelSommer commented 4 months ago

That's what I thought; I just wanted to make sure I wasn't overlooking any subtlety :) Thanks again!