I ran across this issue when trying to evaluate the log_prob of the guide at the positions from an HMC chain. To do this, the positions need to be flattened into the same order the guide uses internally. I found that guide._unpack_latent turns a flat vector into a structured one, so I expected guide._unpack_latent._inverse to do the opposite and turn the structured position into a flat one. What I found was that the order did not match.
Here is some reproducing code:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# make a model where the sample sites are not in alphabetical order
def model_funnel():
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(1), jnp.exp(y / 2)))

# make an auto guide
guide_funnel = numpyro.infer.autoguide.AutoDiagonalNormal(model_funnel)

# fit the guide
optim = numpyro.optim.Adam(step_size=1e-4)
svi = numpyro.infer.SVI(model_funnel, guide_funnel, optim, loss=numpyro.infer.Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 2000)

# make a flat vector to unpack and re-pack
values = jnp.array([1.0, 2.0])
unpack = guide_funnel._unpack_latent(values)
pack = guide_funnel._unpack_latent._inverse(unpack)

print(pack)
# [2.0, 1.0]
print(values)
# [1.0, 2.0]
From what I can tell, ravel_pytree always flattens a dict pytree in sorted (alphabetical) key order, but _unravel_dict splits the flat vector in the order the sample sites show up in the model.
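The key-ordering behavior of ravel_pytree can be checked without numpyro. This is a minimal sketch that assumes a plain dict standing in for the guide's structured latent; the dict and its values are made up for illustration.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Keys inserted in model order: "y" first, then "x".
structured = {"y": jnp.array(1.0), "x": jnp.array([2.0])}

# ravel_pytree flattens dict leaves in sorted key order, so "x" comes first.
flat, unravel = ravel_pytree(structured)
print(flat)

# The unravel function returned alongside flat is its matching inverse.
roundtrip = unravel(flat)
print(roundtrip["y"], roundtrip["x"])
```

The flat vector comes out with the "x" entry ahead of the "y" entry, which is the same swap seen in the reproducing code.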
Possible solutions:
1) Update _unravel_dict to use ravel_pytree's unflatten function instead of defining a custom order (or emulate the same order that function would give)
2) Update UnpackTransform._inverse to invert the _unravel_dict operation rather than assume it is in the same order as ravel_pytree
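As a rough illustration of option 2, an inverse could concatenate the site values in an explicitly given order instead of relying on ravel_pytree. The helper name pack_in_site_order and the site_order argument are hypothetical and not part of numpyro; this is only a sketch of the idea.

```python
import jax.numpy as jnp

def pack_in_site_order(structured, site_order):
    """Hypothetical helper: flatten a dict of site values in an explicit order."""
    return jnp.concatenate(
        [jnp.ravel(jnp.asarray(structured[name])) for name in site_order]
    )

# Same stand-in values as a structured latent with sites "y" and "x".
structured = {"y": jnp.array(1.0), "x": jnp.array([2.0])}

# Pack in the order the sites appear in the model: "y", then "x".
packed = pack_in_site_order(structured, ["y", "x"])
print(packed)
```

With the order taken from the model rather than from sorted keys, the packed vector matches the layout the forward transform expects.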
Hi @CKrawczyk, I think we can raise NotImplementedError or require users to provide a packing function in that transform. That way you can use the "hidden" method guide._unpack_latent to pack things.
What I think is going on is that the _inverse method of the UnpackTransform uses ravel_pytree: https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/transforms.py#L1137
But the forward transform uses the custom _unravel_dict function: https://github.com/pyro-ppl/numpyro/blob/master/numpyro/infer/autoguide.py#L609