pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

An auto guide's `_unpack_latent` and `_unpack_latent._inverse` don't produce the same order #1809

Closed: CKrawczyk closed this issue 1 week ago.

CKrawczyk commented 1 month ago

I ran across this issue when trying to evaluate the `log_prob` of the guide at the positions from an HMC chain. To do this, the positions need to be flattened into the same order the guide uses internally. I found that `guide._unpack_latent` takes a flat vector into a structured one, so I expected `guide._unpack_latent._inverse` to do the opposite and take the structured position into a flat one. However, the resulting order did not match what I expected.
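
For that use case, here is a minimal sketch of flattening a dict of posterior samples (for example from an MCMC run, each site with a leading sample axis) into a `(num_samples, D)` matrix in an explicitly chosen site order. The helper `pack_samples` and its `site_order` argument are hypothetical, not part of NumPyro. The sketch assumes you already know the order the guide uses internally, which per this issue is the order the sample sites appear in the model.

```python
import jax.numpy as jnp


def pack_samples(samples, site_order):
    """Flatten a dict of batched samples into a (num_samples, D) matrix.

    Hypothetical helper: `samples` maps site name -> array of shape
    (num_samples, *event_shape). Columns are laid out site by site in
    `site_order`, so the caller controls the order explicitly instead of
    relying on pytree key sorting.
    """
    num_samples = jnp.shape(samples[site_order[0]])[0]
    columns = [jnp.reshape(samples[name], (num_samples, -1)) for name in site_order]
    return jnp.concatenate(columns, axis=-1)


# 3 samples of a scalar site "y" and a shape-(1,) site "x",
# packed in model order ("y" first) rather than alphabetical order.
samples = {"y": jnp.array([0.1, 0.2, 0.3]), "x": jnp.array([[1.0], [2.0], [3.0]])}
flat = pack_samples(samples, ["y", "x"])
print(flat.shape)  # (3, 2)
```

Each row can then be passed to whatever evaluates the guide's density on flat vectors; for an `AutoContinuous` guide that is presumably the distribution returned by `get_posterior(params)`, but check its exact return type for your NumPyro version, since per-dimension log densities may need to be summed.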

Here is code that reproduces the issue:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.optim import Adam

# make a model where the sample sites are not in alphabetical order
def model_funnel():
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(1), jnp.exp(y / 2)))

# make an auto guide
guide_funnel = AutoDiagonalNormal(model_funnel)

# fit the guide (this also sets up guide_funnel._unpack_latent)
optim = Adam(step_size=1e-4)
svi = SVI(model_funnel, guide_funnel, optim, loss=Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 2000)

# make a flat vector to unpack and re-pack
values = jnp.array([1.0, 2.0])
unpack = guide_funnel._unpack_latent(values)  # {"y": 1.0, "x": [2.0]}
pack = guide_funnel._unpack_latent._inverse(unpack)

print(pack)
# [2.0, 1.0]  <- order swapped

print(values)
# [1.0, 2.0]
```

I think the cause is that the `_inverse` method of `UnpackTransform` uses `ravel_pytree` (https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/transforms.py#L1137), while the forward transform uses the custom `_unravel_dict` function (https://github.com/pyro-ppl/numpyro/blob/master/numpyro/infer/autoguide.py#L609).

From what I can tell, `ravel_pytree` always flattens a dict in sorted (alphabetical) key order, whereas `_unravel_dict` lays out the flat vector in the order the sample sites appear in the model.
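
The ordering difference can be reproduced with JAX alone. This sketch does not call NumPyro's private `_unravel_dict`. Instead, a plain insertion-order concatenation stands in for the layout the issue attributes to it, which is an assumption based on the description above. `ravel_pytree` and `tree_flatten` are the real JAX functions.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_flatten

# Sites in model order: "y" is sampled before "x".
latents = {"y": jnp.array(1.0), "x": jnp.array([2.0])}

# JAX flattens dict pytrees in sorted key order, so "x" comes first.
leaves, _ = tree_flatten(latents)
print([leaf.shape for leaf in leaves])  # [(1,), ()]

flat_sorted, unravel = ravel_pytree(latents)
print(flat_sorted)  # [2. 1.]

# An insertion-order flatten keeps "y" first (the model's site order).
flat_in_order = jnp.concatenate([jnp.ravel(v) for v in latents.values()])
print(flat_in_order)  # [1. 2.]

# Feeding an insertion-order vector to the sorted-order unravel crosses the
# values: "y" receives 2.0 and "x" receives [1.0].
swapped = unravel(flat_in_order)
```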

Possible solutions:
1. Update `_unravel_dict` to use the unflatten function returned by `ravel_pytree` instead of defining a custom order (or emulate the order that function would produce).
2. Update `UnpackTransform._inverse` to invert the `_unravel_dict` operation rather than assume it uses the same order as `ravel_pytree`.
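
Solution 2 amounts to keeping the flatten and unflatten steps paired on the same ordering. A minimal sketch, assuming insertion order is the guide's internal order: the helpers `ravel_in_order` and `unravel_in_order` are hypothetical and only mirror the behavior described in this issue, not NumPyro's actual private functions.

```python
import math

import jax.numpy as jnp


def ravel_in_order(latents):
    """Flatten a dict of arrays in insertion order and record each shape.

    Hypothetical helper: sites are laid out in the order the dict lists
    them, i.e. the model's site order rather than sorted key order.
    """
    shapes = {name: jnp.shape(value) for name, value in latents.items()}
    flat = jnp.concatenate([jnp.ravel(value) for value in latents.values()])
    return flat, shapes


def unravel_in_order(flat, shapes):
    """Inverse of `ravel_in_order`: slice `flat` site by site using `shapes`."""
    out, pos = {}, 0
    for name, shape in shapes.items():
        size = math.prod(shape)  # math.prod(()) == 1 for scalar sites
        out[name] = jnp.reshape(flat[pos:pos + size], shape)
        pos += size
    assert pos == flat.shape[0], "flat vector length does not match shapes"
    return out


latents = {"y": jnp.array(1.0), "x": jnp.array([2.0])}
flat, shapes = ravel_in_order(latents)
restored = unravel_in_order(flat, shapes)
repacked, _ = ravel_in_order(restored)  # round trip keeps "y" first
```

Because both directions iterate over the same `shapes` dict, the round trip cannot silently reorder sites, unlike pairing a custom unflatten with `ravel_pytree`.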

fehiepsi commented 1 month ago

Hi @CKrawczyk, I think we can raise `NotImplementedError` in that transform, or require users to provide a packing function to it. That way you can use the "hidden" `guide._unpack_latent` transform to pack values.
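
A minimal sketch of the design suggested above, using a hypothetical class `UnpackWithOptionalPack` (not NumPyro's actual `UnpackTransform`): the inverse is delegated to an explicitly supplied packing function, and without one it raises `NotImplementedError` instead of silently falling back to a different flattening order. The `unpack`/`pack` pair below is also hypothetical and hard-codes the two sites from the reproducing example.

```python
import jax.numpy as jnp


class UnpackWithOptionalPack:
    """Hypothetical flat-vector -> dict transform with an optional inverse.

    `unpack_fn` maps a flat vector to a dict of site values. The inverse is
    only available when the caller supplies a matching `pack_fn`; otherwise
    it refuses to guess an ordering.
    """

    def __init__(self, unpack_fn, pack_fn=None):
        self.unpack_fn = unpack_fn
        self.pack_fn = pack_fn

    def __call__(self, flat):
        return self.unpack_fn(flat)

    def inverse(self, unpacked):
        if self.pack_fn is None:
            raise NotImplementedError(
                "No pack_fn was provided, so the flattening order is unknown."
            )
        return self.pack_fn(unpacked)


# Hypothetical unpack/pack pair for a scalar "y" followed by a shape-(1,) "x".
def unpack(flat):
    return {"y": flat[0], "x": flat[1:2]}


def pack(unpacked):
    return jnp.concatenate([jnp.ravel(unpacked["y"]), jnp.ravel(unpacked["x"])])


transform = UnpackWithOptionalPack(unpack, pack)
values = jnp.array([1.0, 2.0])
repacked = transform.inverse(transform(values))  # same order as `values`
```

Making the packing function an explicit argument keeps the ordering contract in one place: whoever defines the unpacking also defines its inverse.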