pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

Marginalization is reset by `freeze_dims_and_data` #383

Open jessegrabowski opened 1 month ago

jessegrabowski commented 1 month ago

Wasn't sure which repo this belongs in. If you marginalize a discrete variable with `MarginalModel` and then call `freeze_dims_and_data`, the marginalization is undone:

import pymc as pm
from pymc_experimental import MarginalModel
from pymc.model.transform.optimization import freeze_dims_and_data
import pytensor.tensor as pt

with MarginalModel() as m:
    p = pm.Beta('p', 1, 1)
    idx = pm.Bernoulli('idx', p=p, size=(100,))
    mu = pm.Normal('mu', 0, [1, 100])
    x = pm.Normal('x', pm.math.switch(pt.eq(idx, 0), mu[0], mu[1]), 1)

m.marginal(['idx'])
pm.inputvars(m.logp())   # [p_logodds__, mu, x]

pm.inputvars(freeze_dims_and_data(m).logp())  # Raises ValueError: Random variables detected in the logp graph

Full Traceback

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[19], line 1
----> 1 pm.inputvars(freeze_dims_and_data(m).logp())

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pymc/model/core.py:742, in Model.logp(self, vars, jacobian, sum)
    740 rv_logps: list[TensorVariable] = []
    741 if rvs:
--> 742     rv_logps = transformed_conditional_logp(
    743         rvs=rvs,
    744         rvs_to_values=self.rvs_to_values,
    745         rvs_to_transforms=self.rvs_to_transforms,
    746         jacobian=jacobian,
    747     )
    748     assert isinstance(rv_logps, list)
    750 # Replace random variables by their value variables in potential terms

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pymc/logprob/basic.py:630, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    628 rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logp_terms_list)
    629 if rvs_in_logp_expressions:
--> 630     raise ValueError(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions)
    632 return logp_terms_list

ValueError: Random variables detected in the logp graph: {bernoulli_rv{"()->()"}.out}. This can happen when DensityDist logp or Interval transform functions reference nonlocal variables, or when not all rvs have a corresponding value variable.
```

ricardoV94 commented 1 month ago

MarginalModel is not compatible with any model transformations. It's a temporary limitation until we get rid of the subclass.
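
For readers unfamiliar with what marginalizing `idx` means in the example above, here is a minimal sketch in plain Python, for a single element of `x`. It is not how `MarginalModel` is implemented, and the helper names (`normal_logpdf`, `logsumexp2`, `joint_logp`, `marginal_logp`) are hypothetical, introduced only for illustration. It assumes the model from the issue: `idx` is Bernoulli with probability `p`, and `x` is Normal with standard deviation 1 and mean `mu[0]` when `idx == 0`, else `mu[1]`.

```python
import math


def normal_logpdf(x, mu, sigma=1.0):
    """Log density of Normal(mu, sigma) evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def logsumexp2(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def joint_logp(x, idx, p, mu0, mu1):
    """Log joint density of (x, idx): Bernoulli prior plus Normal likelihood."""
    prior = math.log(p) if idx == 1 else math.log1p(-p)
    mu = mu0 if idx == 0 else mu1
    return prior + normal_logpdf(x, mu)


def marginal_logp(x, p, mu0, mu1):
    """Log density of x with idx summed out over its two possible values."""
    return logsumexp2(
        joint_logp(x, 0, p, mu0, mu1),
        joint_logp(x, 1, p, mu0, mu1),
    )
```

After marginalization the logp depends only on `p`, `mu`, and `x`, which matches the `[p_logodds__, mu, x]` inputs reported before `freeze_dims_and_data` is called. The error in the traceback is consistent with the discrete `idx` variable reappearing in the graph without a value variable.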