pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0
2.15k stars 236 forks source link

bug in NeuTraReparam #1694

Closed amifalk closed 3 months ago

amifalk commented 10 months ago

Minimal example:

import jax
import jax.numpy as jnp
from jax.random import PRNGKey

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Trace_ELBO, SVI
from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer.autoguide import AutoBNAFNormal

# Problem sizes.
n = 100  # number of observations
p = 10  # n_dim x
q = 5  # n_dim y
k = min(3, p, q)  # n_dim latent; never larger than p or q

# Independent standard-normal inputs/outputs: X has shape (n, p), Y has (n, q).
X_dist = dist.MultivariateNormal(jnp.zeros(p), jnp.eye(p))
Y_dist = dist.MultivariateNormal(jnp.zeros(q), jnp.eye(q))
X = X_dist.sample(PRNGKey(0), (n,))
Y = Y_dist.sample(PRNGKey(1), (n,))

def model(X, Y=None):
    """Low-rank linear model of ``Y`` given ``X`` (reproducer for #1694).

    ``X`` is projected to a ``k``-dimensional latent space by ``P`` and then
    mapped to ``q`` output dimensions by ``Q``. The rows of ``P`` and ``Q``
    have diagonal Gaussian priors whose per-dimension variances are drawn
    from ``InverseGamma(3, 1)`` inside their own plates.

    Args:
        X: inputs; ``X @ P`` requires the last axis to have size ``p``
            (the script builds it with shape ``(n, p)``).
        Y: observed outputs, or ``None`` to sample them from the model.

    Returns:
        The value at the ``'Y'`` sample site.
    """
    # Per-dimension prior variances. These plated sites are what hit the
    # NeuTraReparam shape error reported in this issue.
    with numpyro.plate('_k', k):
        P_var = numpyro.sample('P_cov', dist.InverseGamma(3, 1))
    with numpyro.plate('_q', q):
        Q_var = numpyro.sample('Q_cov', dist.InverseGamma(3, 1))

    # Turn the variance vectors into diagonal covariance matrices.
    P_cov = P_var * jnp.eye(k, k)
    Q_cov = Q_var * jnp.eye(q, q)

    # P has shape (p, k): one k-dimensional row per input dimension.
    with numpyro.plate('p', p):
        P = numpyro.sample('P', dist.MultivariateNormal(jnp.zeros(k), P_cov))
    # Q has shape (k, q): one q-dimensional row per latent dimension.
    with numpyro.plate('k', k):
        Q = numpyro.sample('Q', dist.MultivariateNormal(jnp.zeros(q), Q_cov))

    with numpyro.plate('n', n):
        # Low-rank representation of X, transformed back into Y's space.
        Y_pred = (X @ P) @ Q
        likelihood = dist.MultivariateNormal(Y_pred, jnp.eye(q, q))
        return numpyro.sample('Y', likelihood, obs=Y)

# --- this works: plain NUTS on the original model ---
mcmc = MCMC(NUTS(model), num_warmup=50, num_samples=50)
mcmc.run(jax.random.PRNGKey(2), X, Y)

# --- this fails: NUTS on the NeuTra-reparameterized model ---
# Train an AutoBNAFNormal guide with SVI first ...
guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8, 8])
optimizer = numpyro.optim.Adam(0.003)
svi = SVI(model, guide, optimizer, Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(3), 5_000, X, Y)

# ... then reparameterize the model with the trained guide and run NUTS on it.
# This run raises "TypeError: mul got incompatible shapes for broadcasting".
neutra = NeuTraReparam(guide, svi_result.params)
reparam_model = neutra.reparam(model)
mcmc = MCMC(NUTS(reparam_model), num_warmup=1_000, num_samples=3_000)
mcmc.run(jax.random.PRNGKey(4), X, Y)

I'm not entirely sure what's going on here. The model above works with vanilla NUTS, but raises `TypeError: mul got incompatible shapes for broadcasting: (3, 5), (5, 5)` when I run NUTS after reparameterizing it with `NeuTraReparam`.

If I remove the top two plates and replace the latents with the constants

P_cov =  jnp.eye(k, k)
Q_cov = jnp.eye(q, q)

the code runs but I get the following warnings:

<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site '_P_log_prob'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
  mcmc.run(jax.random.PRNGKey(4), X, Y)
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site '_Q_log_prob'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
  mcmc.run(jax.random.PRNGKey(4), X, Y)
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site 'Y'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
  mcmc.run(jax.random.PRNGKey(4), X, Y)

Maybe it has something to do with having multiple plates, with different names, of the same size?

fehiepsi commented 10 months ago

Thanks @amifalk! This is a bug: we allow `plate` to be applied to the unconstrained value, see https://github.com/pyro-ppl/numpyro/blob/b16741cc163b1a3753a331e3200c64cced9eb804/numpyro/infer/reparam.py#L283-L286

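A temporary workaround is to remove the plate around the first site (`P_cov`). Instead, expand its distribution to shape `[k]` and declare that dimension as an event dimension with `to_event()`: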

P_cov = numpyro.sample('P_cov', dist.InverseGamma(3, 1).expand([k]).to_event())