Closed: brandonwillard closed this issue 2 years ago.
Should we just canonicalize these to at.stack([X_rv, Y_rv])[I_rv]?
This should also work for more components, e.g.:
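import aesara.tensor as at
from aesara.tensor.random.utils import RandomStream

from aeppl import factorized_joint_logprob

srng = RandomStream(seed=2320)

# Three-component mixture: I selects X (0), Y (1) or Z (2) elementwise.
# A categorical indicator is used so that the Z branch is reachable
# (a Bernoulli indicator only ever takes the values 0 and 1).
I_rv = srng.categorical([1 / 3, 1 / 3, 1 / 3], size=10, name="I")
X_rv = srng.normal(0, 1, name="X")
Y_rv = srng.gamma(0.5, 0.5, size=10, name="Y")
Z_rv = srng.halfnormal(0, 1, name="Z")

# at.eq is needed: `==` on Aesara variables is not an elementwise comparison
Mix_rv = at.switch(
    at.eq(I_rv, 0),
    X_rv,
    at.switch(
        at.eq(I_rv, 1),
        Y_rv,
        Z_rv,
    ),
)

# Value variables for the mixture and its indicator
Mix_vv = Mix_rv.clone()
i_vv = I_rv.clone()
logp_parts = factorized_joint_logprob({Mix_rv: Mix_vv, I_rv: i_vv})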
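A rough NumPy sketch of the proposed canonical form, using NumPy arrays as stand-ins for the Aesara variables (the seed, draws and gamma scale are arbitrary). It assumes Aesara's advanced indexing follows NumPy's rules. Under that assumption the nested switch and the stack-and-index form agree elementwise only if the components are first broadcast to a common shape (X_rv is a scalar while Y_rv has size 10) and the indicator is paired with a position index; indexing the stack with the indicator alone selects whole rows instead.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10

# NumPy stand-ins for draws of the components in the example above
x = rng.normal(0.0, 1.0)                   # scalar, like X_rv
y = rng.gamma(0.5, 2.0, size=n)            # NumPy's gamma takes (shape, scale)
z = np.abs(rng.normal(0.0, 1.0, size=n))   # half-normal draws
i = rng.integers(0, 3, size=n)             # indicator values in {0, 1, 2}

# Nested elementwise switch, as in the Mix_rv definition
switched = np.where(i == 0, x, np.where(i == 1, y, z))

# Stack-and-index form: broadcast the components to a common shape first
components = np.stack(np.broadcast_arrays(x, y, z))  # shape (3, n)

# Pairing i with arange(n) picks components[i[k], k] for each position k
indexed = components[i, np.arange(n)]

# Equivalent selection along axis 0
along = np.take_along_axis(components, i[None, :], axis=0)[0]

# Indexing with i alone selects whole rows: shape (n, n), not (n,)
naive = components[i]
```

With a scalar indicator the plain stack(...)[I] form is enough; with a vector indicator, one of the two paired forms is what matches the switch.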
Good question. My first thought is that we would also need a "specialization" (or something similar) that undoes that canonicalization when/if at.switch is more efficient to compute than the stacking and indexing.
(This is another one of those cases in which a relational approach—like miniKanren—would be rather convenient.)
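The canonicalize/specialize pair can be sketched as two inverse rewrites on a toy expression tree. Everything here (Var, Switch, StackIndex, canonicalize, specialize) is hypothetical and is not Aesara's graph or rewrite API. The sketch assumes the indicator only takes the values 0..k-1, so the final "else" branch can be treated as case k-1, and it leaves out the cost model that would decide when specializing back to a switch is worthwhile.

```python
from dataclasses import dataclass
from typing import Tuple, Union


# Hypothetical toy IR; not Aesara's graph or rewrite API.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Switch:
    """switch(eq(index, case), then, other)"""
    index: "Expr"
    case: int
    then: "Expr"
    other: "Expr"


@dataclass(frozen=True)
class StackIndex:
    """stack(components)[index], read elementwise"""
    components: Tuple["Expr", ...]
    index: "Expr"


Expr = Union[Var, Switch, StackIndex]


def canonicalize(expr: Expr) -> Expr:
    """Collapse a switch chain on cases 0, 1, ..., k-2 into a StackIndex."""
    components = []
    index = None
    node = expr
    while isinstance(node, Switch):
        if index is None:
            index = node.index
        # Only chains on a single index with consecutive cases from 0 qualify
        if node.index != index or node.case != len(components):
            return expr
        components.append(node.then)
        node = node.other
    if index is None:
        return expr  # not a switch at all
    components.append(node)  # the final "else" branch becomes case k-1
    return StackIndex(tuple(components), index)


def specialize(expr: Expr) -> Expr:
    """Undo canonicalize: rebuild the nested switch from a StackIndex."""
    if not isinstance(expr, StackIndex):
        return expr
    result = expr.components[-1]
    for case in range(len(expr.components) - 2, -1, -1):
        result = Switch(expr.index, case, expr.components[case], result)
    return result


I, X, Y, Z = Var("I"), Var("X"), Var("Y"), Var("Z")
mix = Switch(I, 0, X, Switch(I, 1, Y, Z))
canonical = canonicalize(mix)  # StackIndex(components=(X, Y, Z), index=I)
```

Chains that do not start at case 0, or that switch on different indices, are left untouched, so the round trip specialize(canonicalize(e)) only changes expressions that canonicalize actually recognized.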
You mean when the switch is not actually associated with a value variable? Or do you mean the logprob expression should be defined with a switch because that could be more efficient?
The latter.
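A minimal NumPy sketch of the two ways the mixture's conditional log-density (given the indicator values) could be expressed. Only the normal and half-normal components are used, so no gamma parametrization has to be assumed; the values and indicators are made up for illustration. Both forms give the same numbers, including -inf where a value falls outside the selected component's support. This sketch does not measure which form is cheaper in Aesara.

```python
import numpy as np


def normal_logpdf(v, mu=0.0, sigma=1.0):
    """Log-density of N(mu, sigma**2) at v."""
    return -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((v - mu) / sigma) ** 2


def halfnormal_logpdf(v, sigma=1.0):
    """Half-normal with loc 0: log(2 * N(v; 0, sigma)) for v >= 0, else -inf."""
    return np.where(v >= 0, np.log(2.0) + normal_logpdf(v, 0.0, sigma), -np.inf)


# Hypothetical observed values and indicator values (two components only)
value = np.array([-1.0, 0.5, 2.0, -0.3])
index = np.array([0, 1, 1, 1])

# Switch form: each branch is evaluated directly on `value`, nothing is stacked
switch_logp = np.where(index == 0, normal_logpdf(value), halfnormal_logpdf(value))

# Stack-and-index form: materializes a (2, n) array of component log-densities
stacked = np.stack([normal_logpdf(value), halfnormal_logpdf(value)])
index_logp = stacked[index, np.arange(value.shape[0])]
```

In this NumPy version both forms still evaluate every component on every element; the difference is only whether the stacked (k, n) array is built.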
We need to add support for mixtures defined using the Elemwise+Switch Op. The following is a simple example that isn't supported: