pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

Ordered transform incompatible with constrained space transforms (ZeroSum, Simplex, etc.) #6975

Open michaelosthege opened 10 months ago

michaelosthege commented 10 months ago

Describe the issue:

When the ordered transform is applied to a ZeroSumNormal (ZSN), the model's logp evaluates to -inf.

Reproducible code example:

import pymc as pm

with pm.Model() as pmodel:
    pm.ZeroSumNormal("zsn", shape=2, transform=pm.distributions.transforms.ordered)
pmodel.debug()

Error message:

RuntimeWarning: divide by zero encountered in log
  variables = ufunc(*ufunc_args, **ufunc_kwargs)
point={'zsn_ordered__': array([  0., -inf])}

The variable zsn has the following parameters:
0: normal_rv{0, (0, 0), floatX, False}.1 [id A] <Vector(float64, shape=(2,))>
 ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x1E355016500>) [id B] <RandomGeneratorType>
 ├─ [2] [id C] <Vector(int64, shape=(1,))>
 ├─ 11 [id D] <Scalar(int64, shape=())>
 ├─ 0 [id E] <Scalar(int8, shape=())>
 └─ 1.0 [id F] <Scalar(float64, shape=())>
1: 1.0 [id F] <Scalar(float64, shape=())>
2: [2] [id G] <Vector(int32, shape=(1,))>
The parameters evaluate to:
0: [1.40383536 0.612628  ]
1: 1.0
2: [2]
Some of the values of variable zsn are associated with a non-finite logp:
 value = [  0. -inf] -> logp = -inf

PyMC version information:

5.9.1 at 419af0688353292c0a356cddbb9271737e89a723

Context for the issue:

No response

michaelosthege commented 10 months ago

I looked into it, but my understanding of the ZSN and Ordered transform is limited.

The Ordered transform works by doing a set-subtensor on a diff, and I guess this is what clashes with how the ZSN is also using set_subtensor in its logp?
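A possible explanation for the -inf in the reported point, as a minimal NumPy sketch. It assumes the forward step of the Ordered transform keeps the first entry and takes the log of the successive differences (consistent with the log/exp description later in this thread), and that the starting point of the ZSN is all zeros. The helper name `ordered_forward` is hypothetical and is not part of PyMC.

```python
import numpy as np

def ordered_forward(x):
    """Hypothetical NumPy stand-in for the forward step of an ordered transform."""
    x = np.asarray(x, dtype=float)
    y = np.empty_like(x)
    y[0] = x[0]  # the first entry is left unconstrained
    with np.errstate(divide="ignore"):
        # log of the gaps between neighbours; a zero gap gives log(0) = -inf
        y[1:] = np.log(np.diff(x))
    return y

# Assumed starting point: both entries equal, so the gap between them is zero.
latent = ordered_forward([0.0, 0.0])
print(latent)
```

Under these assumptions the second latent entry is log(0), which matches the `[0., -inf]` point and the divide-by-zero warning in the error message.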

ricardoV94 commented 10 months ago

You need to chain the OrderedTransform with the default transform ZSN has. Same thing as with any distribution that has a default transform.

ricardoV94 commented 10 months ago

Related to https://github.com/pymc-devs/pymc/issues/5674

michaelosthege commented 10 months ago

Okay, makes sense. I tried this, but it doesn't work, so obviously I'm doing something wrong:

with pm.Model() as pmodel:
    rv = pm.ZeroSumNormal.dist(shape=2)
    trafos = [
        pm.distributions.transforms._default_transform(rv.owner.op, rv),
        pm.distributions.transforms.ordered,
    ]
    trafo_chain = pm.distributions.transforms.Chain(trafos)
    pmodel.register_rv(rv, "zsn", transform=trafo_chain)

with pmodel:
    idata = pm.sample()  # assumed: the sampling call is not shown in the original snippet

(idata.posterior["zsn"].sel(zsn_dim_0=0) < idata.posterior["zsn"].sel(zsn_dim_0=1)).mean({"chain", "draw"})
# array(0.50466667)

ricardoV94 commented 10 months ago

This works fine for me:

import pymc as pm

with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.ZeroSumTransform(zerosum_axes=(-1,)),
        pm.distributions.transforms.ordered,
    ])
    pm.ZeroSumNormal("zsn", shape=(3,), transform=transform, initval=[-1, 0, 1])
    # or pm.ZeroSumNormal("zsn", shape=(5,), transform=transform, initval=[-2, -1, 0, 1, 2])

    idata = pm.sample(chains=1, draws=100)
assert (idata.posterior["zsn"].sel(zsn_dim_0=0) < idata.posterior["zsn"].sel(zsn_dim_0=1)).all()

Might be something special about the ZSN with shape=(2,) as that is in fact only sampling one single value.

ricardoV94 commented 10 months ago

As implemented, the OrderedTransform is incompatible with the ZeroSum transform.

In the shape=(2,) case it doesn't do anything, because there is nothing to order in the latent space (the zerosum transform works on an n-1 space). So if the sampler proposes a positive value b, the final variable will be proportional to [b, -b] after the backward transform, i.e. in decreasing order.

But even with more entries, ordering these n-1 latent entries doesn't guarantee that the output of the zerosum transform is ordered (e.g., if the ordered latent values have a sum > 0, the missing entry, which is always the last, has to be smaller than at least one previous entry to balance them out).

import pymc as pm

with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.ZeroSumTransform(zerosum_axes=(-1,)),
        pm.distributions.transforms.ordered,
    ])
    pm.ZeroSumNormal("zsn", shape=(2,), transform=transform)

zsn_latent_value = m.value_vars[-1]
zsn_value = m.unobserved_value_vars[-1]
print(zsn_value.eval({zsn_latent_value: [1.5]}))  # array([ 1.06066017, -1.06066017])

with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.ZeroSumTransform(zerosum_axes=(-1,)),
        pm.distributions.transforms.ordered,
    ])
    pm.ZeroSumNormal("zsn", shape=(3,), transform=transform)

zsn_latent_value = m.value_vars[-1]
zsn_value = m.unobserved_value_vars[-1]
print(zsn_value.eval({zsn_latent_value: [1.5, 2.5]}))  # [-1.70843849 10.47405547 -8.76561698]

The other way around doesn't work either, because the ordering transform distorts the values (there is an exp in the backward and a log in the forward steps), so multiple entries will no longer sum to zero after the chained backward transform.
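Both orderings can be compared in plain NumPy. This is only a sketch: the helpers `ordered_backward` and `zerosum_backward` are hypothetical stand-ins, and their formulas are assumptions chosen because they reproduce the numbers printed in the shape=(3,) example above. They are not taken from PyMC's API.

```python
import numpy as np

def ordered_backward(y):
    """Keep y[0], exponentiate the rest, then take the cumulative sum."""
    y = np.asarray(y, dtype=float)
    steps = np.concatenate([y[:1], np.exp(y[1:])])
    return np.cumsum(steps)

def zerosum_backward(y):
    """Map n-1 free values to n values that sum to zero (assumed formula)."""
    y = np.asarray(y, dtype=float)
    n = y.shape[0] + 1
    total = y.sum()
    norm = total / (np.sqrt(n) + n)
    fill = norm - total / np.sqrt(n)  # the appended last entry
    return np.concatenate([y, [fill]]) - norm

latent = [1.5, 2.5]

# Same order as the chained backward pass above: ordering first, then zero-sum.
zerosum_last = zerosum_backward(ordered_backward(latent))
# The other way around: zero-sum first, then ordering.
ordered_last = ordered_backward(zerosum_backward(latent))

print(zerosum_last, zerosum_last.sum())  # sums to zero, but is not increasing
print(ordered_last, ordered_last.sum())  # increasing, but does not sum to zero
```

Whichever transform is applied last in the backward pass keeps its constraint, and the other constraint is lost.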

ricardoV94 commented 10 months ago

This limitation is not specific to the ZeroSumTransform; it also applies, for example, to the SimplexTransform:

import pymc as pm

with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.simplex,
        pm.distributions.transforms.ordered,
    ])
    pm.Dirichlet("x", a=[1, 1, 1], shape=(3,), transform=transform)

zsn_latent_value = m.value_vars[-1]
zsn_value = m.unobserved_value_vars[-1]
print(zsn_value.eval({zsn_latent_value: [.5, .75]}))  # [0.10714617 0.88997555 0.00287828]

Again the other order (simplex first, then ordering in the backward step) wouldn't provide a valid simplex anymore, because of the exp operation.

The OrderedTransform seems to only be useful in unconstrained spaces.
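The same comparison can be sketched for the simplex case in NumPy. The helper names are hypothetical, and the simplex backward step is assumed to append the negative sum of the latent values and then apply a softmax; the latent input `[0.5, 0.75]` is an illustrative choice.

```python
import numpy as np

def ordered_backward(y):
    """Keep y[0], exponentiate the rest, then take the cumulative sum."""
    y = np.asarray(y, dtype=float)
    return np.cumsum(np.concatenate([y[:1], np.exp(y[1:])]))

def simplex_backward(y):
    """Append -sum(y), then apply a numerically stable softmax (assumed formula)."""
    y = np.asarray(y, dtype=float)
    full = np.concatenate([y, [-y.sum()]])
    e = np.exp(full - full.max())
    return e / e.sum()

latent = [0.5, 0.75]  # illustrative latent values

# Ordering first, simplex last: a valid simplex, but not increasing.
simplex_last = simplex_backward(ordered_backward(latent))
# Simplex first, ordering last: increasing, but no longer sums to one.
ordered_last = ordered_backward(simplex_backward(latent))

print(simplex_last, simplex_last.sum())
print(ordered_last, ordered_last.sum())
```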

michaelosthege commented 10 months ago

Thank you for investigating!

Would it help to drop the log and exp from the Ordered transform? I experimented with this, and IIUC it was just added for numeric stability, which might already be subject to graph rewrites?

ricardoV94 commented 10 months ago

> Thank you for investigating!
>
> Would it help to drop the log and exp from the Ordered transform? I experimented with this, and IIUC it was just added for numeric stability, which might already be subject to graph rewrites?

No, it is not about numerical stability. It's a way to achieve ordering in a continuous way. You could probably come up with a continuous constrained ordering, but you would need to find a different expression and make sure you have the correct Jacobian for it.
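To illustrate what "ordering in a continuous way" with a matching Jacobian involves, here is a NumPy sketch of an ordered transform built from log and exp. All function names are hypothetical, and the formulas are an assumption about how such a transform can be written, not a copy of PyMC's implementation. The backward map has a lower-triangular Jacobian with diagonal `[1, exp(y[1]), ...]`, so its log-determinant is the sum of `y[1:]`; the sketch checks this against finite differences.

```python
import numpy as np

def ordered_forward(x):
    """Increasing vector -> unconstrained vector (log of the gaps)."""
    x = np.asarray(x, dtype=float)
    return np.concatenate([x[:1], np.log(np.diff(x))])

def ordered_backward(y):
    """Unconstrained vector -> strictly increasing vector."""
    y = np.asarray(y, dtype=float)
    return np.cumsum(np.concatenate([y[:1], np.exp(y[1:])]))

def log_jac_det(y):
    """Analytic log|det J| of the backward map."""
    return float(np.sum(np.asarray(y, dtype=float)[1:]))

def numeric_log_jac_det(y, eps=1e-6):
    """Finite-difference check of the same quantity."""
    y = np.asarray(y, dtype=float)
    n = y.size
    jac = np.empty((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = eps
        jac[:, j] = (ordered_backward(y + step) - ordered_backward(y - step)) / (2 * eps)
    _, logdet = np.linalg.slogdet(jac)
    return float(logdet)

y = np.array([-0.3, 0.2, 1.1])  # any real vector is a valid latent value
x = ordered_backward(y)
print(x)                                         # strictly increasing
print(log_jac_det(y), numeric_log_jac_det(y))    # the two should agree
```

Any replacement expression that also enforces a sum-to-zero or simplex constraint would need its own log-determinant derived and checked in the same way.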