aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License

Support mixtures defined with `switch` #77

Closed: brandonwillard closed this issue 2 years ago

brandonwillard commented 3 years ago

We need to add support for mixtures defined using the `Elemwise`-wrapped `Switch` `Op` (i.e. `at.switch`).

The following is a simple example that isn't supported:

```python
import aesara.tensor as at

from aeppl.joint_logprob import factorized_joint_logprob

srng = at.random.RandomStream(seed=2320)

I_rv = srng.bernoulli(0.5, size=10, name="I")
X_rv = srng.normal(0, 1, name="X")
Y_rv = srng.gamma(0.5, 0.5, size=10, name="Y")

# Element-wise mixture: X where I is true, Y elsewhere
Z_rv = at.switch(I_rv, X_rv, Y_rv)
Z_rv.name = "Z"

z_vv = Z_rv.clone()
i_vv = I_rv.clone()
logp_parts = factorized_joint_logprob({Z_rv: z_vv, I_rv: i_vv})
```
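
For reference, the sketch below writes out in plain NumPy what the two factorized terms could evaluate to. It rests on two assumptions: each element is treated as an independent mixture draw (as if `X_rv` also had `size=10`; with a scalar `X_rv`, every selected element shares a single draw), and `gamma(0.5, 0.5)` is read as shape 0.5 and scale 0.5, as in NumPy's `Generator.gamma`. The helper functions are hypothetical and are not part of aeppl.

```python
import math

import numpy as np


def normal_logpdf(z, mu=0.0, sigma=1.0):
    # log N(z; mu, sigma^2)
    z = np.asarray(z, dtype=float)
    return -0.5 * ((z - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def gamma_logpdf(z, shape=0.5, scale=0.5):
    # Assumed shape/scale parameterization; -inf outside the support z > 0
    z = np.asarray(z, dtype=float)
    safe_z = np.where(z > 0, z, 1.0)  # avoid log(<= 0) in masked-out entries
    logp = (
        (shape - 1) * np.log(safe_z)
        - safe_z / scale
        - math.lgamma(shape)
        - shape * math.log(scale)
    )
    return np.where(z > 0, logp, -np.inf)


def bernoulli_logpmf(i, p=0.5):
    i = np.asarray(i)
    return np.where(i == 1, math.log(p), math.log(1 - p))


def switch_mixture_logp(z, i):
    # Term for Z given I: the component density is selected element-wise,
    # mirroring at.switch(I_rv, X_rv, Y_rv) (I == 1 -> X, I == 0 -> Y)
    return np.where(np.asarray(i) == 1, normal_logpdf(z), gamma_logpdf(z))


i_val = np.array([1, 0, 1, 0])
z_val = np.array([-0.3, 0.7, 1.2, 2.0])
logp_z = switch_mixture_logp(z_val, i_val)  # term for the Z value variable
logp_i = bernoulli_logpmf(i_val)            # term for the I value variable
```

The mixture term is itself a `switch` over the component log-densities, which is one natural target for the logprob rewrite.
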
ricardoV94 commented 3 years ago

Should we just canonicalize these to `at.stack([X_rv, Y_rv])[I_rv]`?
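
Assuming the intended canonical form selects element-wise, the two formulations can be compared in NumPy. Two details matter. First, `at.switch(I, X, Y)` returns `X` where `I` is true, so with a 0/1 index the stack order needs to be `[Y, X]`. Second, indexing a stack with a length-10 index vector selects whole rows, so the element-wise version needs the components broadcast to a common shape and a paired `arange` index. This is a NumPy sketch, not the Aesara graph itself.

```python
import numpy as np

rng = np.random.default_rng(2320)
n = 10
i = rng.integers(0, 2, size=n)    # plays the role of I_rv
x = rng.normal(0.0, 1.0)          # scalar, like X_rv
y = rng.gamma(0.5, 0.5, size=n)   # like Y_rv

# Switch form: x where i is true, y elsewhere
switch_form = np.where(i, x, y)

# Stack-and-index form: broadcast to a common shape, stack along a new
# leading axis (order [y, x] so that index 1 -> x), then pick one entry
# per position
components = np.stack(np.broadcast_arrays(y, x))
stack_form = components[i, np.arange(n)]

# Plain row indexing returns an (n, n) array instead of n elements
naive = components[i]
```
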

This should also work for more components:

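```python
import aesara.tensor as at

from aeppl.joint_logprob import factorized_joint_logprob

srng = at.random.RandomStream(seed=2320)

# A three-valued index, so that every branch of the nested switch can be selected
I_rv = srng.categorical([1 / 3, 1 / 3, 1 / 3], size=10, name="I")

X_rv = srng.normal(0, 1, name="X")
Y_rv = srng.gamma(0.5, 0.5, size=10, name="Y")
Z_rv = srng.halfnormal(0, 1, name="Z")

Mix_rv = at.switch(
    I_rv == 0,
    X_rv,
    at.switch(
        I_rv == 1,
        Y_rv,
        Z_rv,
    ),
)

Mix_vv = Mix_rv.clone()
i_vv = I_rv.clone()
logp_parts = factorized_joint_logprob({Mix_rv: Mix_vv, I_rv: i_vv})
```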
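
In this nested form (index 0 -> X, 1 -> Y, otherwise Z) the branch order already matches the stack order `[X, Y, Z]`, and NumPy's `np.choose` performs the same element-wise selection with broadcasting. The sketch below checks that equivalence in NumPy, using a three-valued index as a stand-in for `I_rv`; the component draws are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10
i = rng.choice(3, size=n, p=[0.3, 0.3, 0.4])  # three-valued index, like I_rv
x = rng.normal(0.0, 1.0)                      # scalar component
y = rng.gamma(0.5, 0.5, size=n)               # vector component
z = abs(rng.normal(0.0, 1.0))                 # half-normal scalar draw

# Nested switch, as in the example above
nested_switch = np.where(i == 0, x, np.where(i == 1, y, z))

# Element-wise selection from the components in the order [x, y, z]
chosen = np.choose(i, [x, y, z])

# The same selection via an explicit stack and take_along_axis
stacked = np.stack(np.broadcast_arrays(x, y, z))   # shape (3, n)
indexed = np.take_along_axis(stacked, i[None, :], axis=0)[0]
```
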
brandonwillard commented 3 years ago

> Should we just canonicalize these to `at.stack([X_rv, Y_rv])[I_rv]`?

Good question. My first thought is that we would also need a "specialization" (or something similar) that undoes that canonicalization when/if a `switch` is more efficient to compute than the stacking and indexing.

(This is another one of those cases in which a relational approach—like miniKanren—would be rather convenient.)
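
To make the canonicalize/specialize pairing concrete, here is a toy sketch on a hypothetical tuple-based expression format (not Aesara's graph or rewrite API). `canonicalize` turns a two-way `switch` into a stack-and-index form, `specialize` turns it back, and a small evaluator checks that both forms denote the same values.

```python
import numpy as np

# Hypothetical expression format (illustration only):
#   ("switch", cond, a, b)          -> a where cond is true, else b
#   ("index", ("stack", *cs), idx)  -> element-wise selection cs[idx]
#   a plain string                  -> leaf looked up in an environment


def canonicalize(expr):
    # switch(cond, a, b) -> stack([b, a])[cond]; index 1 (true) selects a
    if isinstance(expr, tuple) and expr[0] == "switch":
        _, cond, a, b = expr
        return ("index", ("stack", canonicalize(b), canonicalize(a)), cond)
    return expr


def specialize(expr):
    # Undo the canonical form when exactly two components are stacked
    if isinstance(expr, tuple) and expr[0] == "index":
        _, stacked, idx = expr
        if isinstance(stacked, tuple) and stacked[0] == "stack" and len(stacked) == 3:
            _, b, a = stacked
            return ("switch", idx, specialize(a), specialize(b))
    return expr


def evaluate(expr, env):
    if isinstance(expr, str):
        return np.asarray(env[expr])
    op = expr[0]
    if op == "switch":
        _, cond, a, b = expr
        return np.where(evaluate(cond, env), evaluate(a, env), evaluate(b, env))
    if op == "index":
        _, (_, *cs), idx = expr
        choices = [evaluate(c, env) for c in cs]
        return np.choose(evaluate(idx, env).astype(int), choices)
    raise ValueError(f"unknown op {op!r}")


expr = ("switch", "I", "X", "Y")
env = {
    "I": np.array([1, 0, 1]),
    "X": np.array([10.0, 20.0, 30.0]),
    "Y": np.array([-1.0, -2.0, -3.0]),
}
canon = canonicalize(expr)       # ("index", ("stack", "Y", "X"), "I")
result = evaluate(canon, env)    # array([10., -2., 30.])
```

A real rewrite system would also need a cost model to decide when to apply `specialize`; that decision is what the comment above is about, and it is not modeled here.
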

ricardoV94 commented 2 years ago

> Should we just canonicalize these to `at.stack([X_rv, Y_rv])[I_rv]`?

> Good question. My first thought is that we would also need a "specialization" (or something similar) that undoes that canonicalization when/if a `switch` is more efficient to compute than the stacking and indexing.

You mean when the switch is not actually associated with a value variable? Or do you mean the logprob expression should be defined with a switch because that could be more efficient?

brandonwillard commented 2 years ago

> You mean when the switch is not actually associated with a value variable? Or do you mean the logprob expression should be defined with a switch because that could be more efficient?

The latter.