michaelosthege opened 10 months ago
I looked into it, but my understanding of the ZSN and Ordered transform is limited.
The Ordered transform works by doing a set-subtensor on a diff, and I guess this is what clashes with how the ZSN is also using set_subtensor in its logp?
You need to chain the OrderedTransform with the default transform ZSN has. Same thing as with any distribution that has a default transform.
Okay, makes sense. I tried this, but it doesn't work, so obviously I'm doing something wrong:
```python
with pm.Model() as pmodel:
    rv = pm.ZeroSumNormal.dist(shape=2)
    trafos = [
        pm.distributions.transforms._default_transform(rv.owner.op, rv),
        pm.distributions.transforms.ordered,
    ]
    trafo_chain = pm.distributions.transforms.Chain(trafos)
    pmodel.register_rv(rv, "zsn", transform=trafo_chain)

(idata.posterior["zsn"].sel(zsn_dim_0=0) < idata.posterior["zsn"].sel(zsn_dim_0=1)).mean({"chain", "draw"})
# array(0.50466667)
```
This works fine for me:
```python
import pymc as pm

with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.ZeroSumTransform(zerosum_axes=(-1,)),
        pm.distributions.transforms.ordered,
    ])
    pm.ZeroSumNormal("zsn", shape=(3,), transform=transform, initval=[-1, 0, 1])
    # or pm.ZeroSumNormal("zsn", shape=(5,), transform=transform, initval=[-2, -1, 0, 1, 2])
    idata = pm.sample(chains=1, draws=100)

assert (idata.posterior["zsn"].sel(zsn_dim_0=0) < idata.posterior["zsn"].sel(zsn_dim_0=1)).all()
```
Might be something special about the ZSN with shape=(2,), as that is in fact only sampling one single value.
As implemented, the OrderedTransform is incompatible with the ZeroSumTransform. In the shape=(2,) case the ordering doesn't do anything, because there is nothing to order in the latent space (the zero-sum transform works on an n-1 dimensional space). So if the sampler proposes a positive value b, the final variable will be [b, -b] after the backward transform.

But even with more entries, ordering the n-1 latent entries doesn't mean the zero-sum variable will come out ordered. For example, if the latent variables have a positive sum after the ordering, the missing entry, which is always the last, has to be smaller than at least one previous entry to balance it out:
```python
import pymc as pm

with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.ZeroSumTransform(zerosum_axes=(-1,)),
        pm.distributions.transforms.ordered,
    ])
    pm.ZeroSumNormal("zsn", shape=(2,), transform=transform)

zsn_latent_value = m.value_vars[-1]
zsn_value = m.unobserved_value_vars[-1]
print(zsn_value.eval({zsn_latent_value: [1.5]}))  # array([ 1.06066017, -1.06066017])
```
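Incidentally, the printed pair is exactly [1.5/√2, -1.5/√2]: for shape=(2,) the single latent scalar appears to be mapped onto the zero-sum line by a fixed rotation (the general zero-sum backward is a Householder-style reflection; the √2 factor here is inferred from the printed numbers, not read off the transform's source):

```python
import numpy as np

# Hypothetical check: for shape=(2,) the single latent value y seems to
# map to [y/sqrt(2), -y/sqrt(2)], which reproduces the printed output above.
y = 1.5
x = np.array([y, -y]) / np.sqrt(2)
print(x)        # [ 1.06066017 -1.06066017]
print(x.sum())  # 0.0
```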
```python
with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.ZeroSumTransform(zerosum_axes=(-1,)),
        pm.distributions.transforms.ordered,
    ])
    pm.ZeroSumNormal("zsn", shape=(3,), transform=transform)

zsn_latent_value = m.value_vars[-1]
zsn_value = m.unobserved_value_vars[-1]
print(zsn_value.eval({zsn_latent_value: [1.5, 2.5]}))  # [-1.70843849 10.47405547 -8.76561698]
```
The other way around doesn't work either: the ordered transform distorts the values (there is an exp in the backward and a log in the forward step), so the entries will no longer sum to zero after the chained backward transform.
This limitation is not specific to the ZeroSumTransform; it also applies, for example, to the SimplexTransform:
```python
import pymc as pm

with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.simplex,
        pm.distributions.transforms.ordered,
    ])
    pm.Dirichlet("x", a=[1, 1, 1], shape=(3,), transform=transform)

zsn_latent_value = m.value_vars[-1]
zsn_value = m.unobserved_value_vars[-1]
print(zsn_value.eval({zsn_latent_value: [.5, -.75]}))  # [0.10714617 0.88997555 0.00287828]
```
Again, the other order (simplex first, then ordering in the backward step) wouldn't produce a valid simplex anymore, because of the exp operation. The OrderedTransform seems to be useful only in unconstrained spaces.
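The same NumPy sketch of the ordered backward pass (first entry kept, remaining entries exponentiated, then cumulative-summed) shows what happens when it is applied to a valid simplex point:

```python
import numpy as np

def ordered_backward(y):
    # Sketch of Ordered.backward: positive increments via exp, then cumsum.
    x = np.concatenate([y[..., :1], np.exp(y[..., 1:])], axis=-1)
    return np.cumsum(x, axis=-1)

p = np.array([0.2, 0.3, 0.5])  # a valid simplex point (sums to 1)
q = ordered_backward(p)
print(q)        # [0.2        1.54985881 3.19858008]
print(q.sum())  # ~4.95 -- entries exceed 1 and no longer sum to 1
```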
Thank you for investigating!

Would it help to drop the `log` and `exp` from the `Ordered` transform? I experimented with this, and IIUC it was just added for numeric stability, which might already be subject to graph rewrites?
> Thank you for investigating!
> Would it help to drop the `log` and `exp` from the `Ordered` transform? I experimented with this, and IIUC it was just added for numeric stability, which might already be subject to graph rewrites?
No, it is not about numerical stability. It's a way to achieve ordering in a continuous way. You could probably come up with a continuous constrained ordering, but you would need to find a different expression and make sure you have the correct Jacobian for it.
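To see why the exp is structural rather than a stability trick: without it, the backward pass would reduce to a plain cumulative sum, and a negative latent entry would break the ordering. A minimal NumPy sketch:

```python
import numpy as np

# Without exp, Ordered.backward would be a plain cumsum, and a negative
# latent entry produces a decreasing step in the output:
y = np.array([0.5, -2.0, 1.0])
x = np.cumsum(y)
print(x)  # [ 0.5 -1.5 -0.5] -- not ordered
```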
Describe the issue:
When applying the ordered transform to the ZSN the model suffers a `-inf` logp.

Reproducible code example:

Error message:

PyMC version information:
5.9.1 at 419af0688353292c0a356cddbb9271737e89a723

Context for the issue:
No response