pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

MCMC + random_flax_module don't work when specifying priors #1873

Closed: sheinkmana closed this issue 1 month ago

sheinkmana commented 1 month ago

Hey,

I am trying to implement a feed-forward neural network using a Flax module and have run into a problem. SVI works fine, but MCMC (I tried both NUTS and HMC) doesn't work whenever I pass the priors as a dict. It only works when a single prior is shared by all parameters (e.g. prior=dist.Normal() works but prior={"bias": dist.Normal(), "kernel": dist.Normal()} doesn't). I don't get any errors: inference just completes in a matter of seconds, and the posterior predictive distribution is no different from the prior predictive. My Flax part:

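from typing import Sequence

import flax.linen as nn
import jax.numpy as jnp


class FlaxMLP_relu(nn.Module):
    hidden_dims: Sequence[int]

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Add a trailing feature axis so the input becomes (..., 1).
        x = nn.Dense(features=self.hidden_dims[0], name="Dense1")(x[..., None])
        x = nn.relu(x)
        # Remaining hidden layers are named Dense2, Dense3, ...
        for i, hidden_dim in enumerate(self.hidden_dims[1:], 2):
            x = nn.Dense(features=hidden_dim, name=f"Dense{i}")(x)
            x = nn.relu(x)

        # The output layer has no explicit name, so Flax auto-names it (Dense_0).
        x = nn.Dense(1)(x)
        return x.squeeze()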
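To see which keys a dict prior has to use for this network, it helps to list its parameters under their full nested names. Below is a minimal, dependency-free sketch, assuming that random_flax_module names each parameter by joining the sub-module name and the leaf name with "." (e.g. Dense1.bias). The shapes are hypothetical values for hidden_dims=(8, 8) and input_shape=(1, 1), and flatten_param_names is an illustrative helper, not part of NumPyro or Flax.

```python
from typing import Any, Dict, List


def flatten_param_names(tree: Dict[str, Any], prefix: str = "") -> List[str]:
    """Return a dotted name for every leaf of a nested parameter dict."""
    names: List[str] = []
    for key, value in tree.items():
        full_name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            # Sub-module: recurse, carrying the module name as a prefix.
            names.extend(flatten_param_names(value, full_name))
        else:
            # Leaf (here just a shape tuple standing in for an array).
            names.append(full_name)
    return names


# Hypothetical parameter tree for FlaxMLP_relu(hidden_dims=(8, 8)), input_shape=(1, 1).
params = {
    "Dense1": {"kernel": (1, 8), "bias": (8,)},
    "Dense2": {"kernel": (8, 8), "bias": (8,)},
    "Dense_0": {"kernel": (8, 1), "bias": (1,)},
}

print(flatten_param_names(params))
# ['Dense1.kernel', 'Dense1.bias', 'Dense2.kernel', 'Dense2.bias', 'Dense_0.kernel', 'Dense_0.bias']
```

None of these names is a plain bias or kernel. A lone nn.Dense, by contrast, has top-level leaves (flatten_param_names({"kernel": (4, 1), "bias": (1,)}) gives ['kernel', 'bias']), which is why a {"bias": ..., "kernel": ...} dict suits a single layer but not a module with nested layers.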

The model, with priors taken from the documentation (this and more complicated versions with scaling work fine with SVI):

    def model(self, x, y=None, hidden_dims=hidden_dims) -> None:
        # Dict keys are looked up against the network's parameter names.
        net = random_flax_module(
            "nn",
            FlaxMLP_relu(hidden_dims),
            input_shape=(1, 1),
            prior={"bias": dist.Cauchy(), "kernel": dist.Normal()},
        )
        mu = numpyro.deterministic("mu", net(x))
        prec_obj = numpyro.sample("prec_obj", dist.HalfNormal(0.01))
        sigma = prec_obj  # used directly as the observation scale
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

Would you be able to help?

Huge thanks in advance!

fehiepsi commented 1 month ago

Hi @sheinkmana, I guess the keys in prior={"bias": dist.Cauchy(), "kernel": dist.Normal()} do not match the parameter names in the neural network.
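This mismatch would also explain why nothing fails loudly. A minimal, stdlib-only sketch, assuming random_flax_module assigns a dict prior to a parameter only on an exact match of its full dotted name, and leaves any unmatched parameter at its initial value instead of sampling it. Under that assumption, the keys from the question cover no network weight at all, so NUTS only samples prec_obj and the network stays fixed, which fits both the fast run and the unchanged posterior predictive. The names assume hidden_dims=(8, 8), and split_by_prior and expand_prior are hypothetical helpers, not NumPyro API.

```python
from typing import Any, Dict, List, Tuple


def split_by_prior(names: List[str], prior: Dict[str, Any]) -> Tuple[List[str], List[str]]:
    """Split parameter names into those the dict prior covers and those it misses."""
    matched = [n for n in names if n in prior]
    missed = [n for n in names if n not in prior]
    return matched, missed


def expand_prior(names: List[str], by_leaf: Dict[str, Any]) -> Dict[str, Any]:
    """Build a fully qualified prior dict from priors keyed by leaf name.

    {"bias": A, "kernel": B} -> {"Dense1.bias": A, "Dense1.kernel": B, ...}.
    Raises KeyError if some leaf name has no entry in by_leaf.
    """
    return {n: by_leaf[n.rsplit(".", 1)[-1]] for n in names}


# Full parameter names of FlaxMLP_relu(hidden_dims=(8, 8)) (assumed).
names = ["Dense1.kernel", "Dense1.bias", "Dense2.kernel", "Dense2.bias",
         "Dense_0.kernel", "Dense_0.bias"]

# Strings stand in for distributions such as dist.Cauchy() and dist.Normal().
leaf_prior = {"bias": "Cauchy()", "kernel": "Normal()"}

print(split_by_prior(names, leaf_prior))  # nothing matched, every name missed
fixed = expand_prior(names, leaf_prior)
print(split_by_prior(names, fixed))       # every name matched, nothing missed
```

In the actual model the values would be NumPyro distributions, e.g. a dict literal such as {"Dense1.bias": dist.Cauchy(), "Dense1.kernel": dist.Normal(), ...} passed as prior. SVI presumably looked fine because the unmatched weights are still trainable parameters that the optimizer updates, whereas MCMC only explores sampled sites.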

sheinkmana commented 1 month ago

Thanks a lot, @fehiepsi! That's super helpful. (Not going to lie, it took me an embarrassingly long time to realize it.)