pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io
Other
77 stars 49 forks source link

Bug in `get_domain_of_finite_discrete_rv` of `Categorical` #331

Open ricardoV94 opened 5 months ago

ricardoV94 commented 5 months ago

Reported by @jessegrabowski

with MarginalModel(coords=coords) as m:
    x_data = pm.ConstantData('x', df.x, dims=['obs_idx'])
    y_data = pm.ConstantData('y', df.y, dims=['obs_idx'])

    X = pt.concatenate([pt.ones_like(x_data[:, None]), x_data[:, None], x_data[:, None] ** 2], axis=-1)

    mu = pm.Normal('mu', dims=['group'])
    beta_p = pm.Normal('beta_p', dims=['params', 'group'])
    logit_p_group = X @ beta_p
    group_idx = pm.Categorical('group_idx', logit_p=logit_p_group, dims=['obs_idx'])
    sigma = pm.Exponential('sigma', 1)

    mu = pt.switch(pt.lt(group_idx, 1), 
                   mu_trend,
                   pt.switch(pt.lt(group_idx, 2), 
                             p_x[:, 0], 
                             p_x[:, 1])
                  )

    y_hat = pm.Normal('y_hat', 
                      mu = mu,
                      sigma = sigma,
                      observed=y_data,
                      dims=['obs_idx'])

m.marginalize(["group_idx"])
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pymc_experimental/model/marginal_model.py:655, in get_domain_of_finite_discrete_rv(rv)
    653 elif isinstance(op, Categorical):
    654     p_param = rv.owner.inputs[3]
--> 655     return tuple(range(pt.get_vector_length(p_param)))
    656 elif isinstance(op, DiscreteUniform):
    657     lower, upper = constant_fold(rv.owner.inputs[3:])

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/tensor/__init__.py:82, in get_vector_length(v)
     79 v = as_tensor_variable(v)
     81 if v.type.ndim != 1:
---> 82     raise TypeError(f"Argument must be a vector; got {v.type}")
     84 static_shape: Optional[int] = v.type.shape[0]
     85 if static_shape is not None:

TypeError: Argument must be a vector; got Matrix(float64, shape=(256, 3))

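Instead of trying to get the vector length of `p_param` (which assumes `p` is always a vector), we should constant-fold `p_param.shape[-1]`. The number of categories is the size of the last axis, so this also works for batched `p`; in the example above `p` has shape `(256, 3)` and the domain should be `(0, 1, 2)`.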