pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

Error message from build_statespace_graph when a cycle is one of the model components #281

Closed: rklees closed this issue 6 months ago.

rklees commented 6 months ago

When I build a PyMC statespace model with pymc-experimental, I always get an error message when adding a cycle to the model:

"TypeError: The type of the replacement (Vector(float64, shape=(2,))) must be compatible with the type of the original Variable (Vector(float64, shape=(1,)))."

The code is below. The error message refers to the line "annual_cycle = ...".

    import pymc as pm
    import pytensor.tensor as pt
    from pymc_experimental.statespace import structural as st

    # Integrated random walk trend + annual cycle + measurement error
    mod = st.LevelTrendComponent(order=2, innovations_order=[0, 1])
    mod += st.CycleComponent(name="annual_cycle", cycle_length=12, innovations=True)
    mod += st.MeasurementError(name="obs")
    model = mod.build(name="IRW+cycle+measurement_error")

    # Inspect the parameter dims expected by the statespace model
    model.param_dims
    initial_trend_dims, sigma_trend_dims, annual_cycle_dims, sigma_obs_dims, P0_dims = model.param_dims.values()
    coords = model.coords

    with pm.Model(coords=coords) as model_1:
        # Diagonal prior on the initial state covariance
        P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
        P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)

        initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
        sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=10, dims=sigma_trend_dims)

        # The reported error points to this variable
        annual_cycle = pm.Normal("annual_cycle", sigma=5, dims=annual_cycle_dims)
        sigma_annual_cycle = pm.Gamma("sigma_annual_cycle", alpha=2, beta=5)

        sigma_obs = pm.Gamma("sigma_obs", alpha=2, beta=5, dims=["observed_state"])

        # `data` is the observed time series (not included in the report)
        model.build_statespace_graph(data, mode="JAX")
        idata = pm.sample(nuts_sampler="numpyro", target_accept=0.9)
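One way to narrow this down is to compare the shape each prior receives from its `dims` with the shape the statespace graph expects. PyMC sizes a variable created with `dims=...` by looking up the length of each named coordinate in the model's `coords`. The helpers below are a hypothetical plain-Python sketch of that lookup and comparison, not part of PyMC or pymc-experimental; the coordinate name `annual_cycle_state` and its values are illustrative assumptions, not the actual contents of `model.coords`.

```python
def shape_from_dims(dims, coords):
    """Hypothetical helper: resolve a dims tuple to a shape via a coords dict.

    Each named dimension contributes the number of its coordinate values,
    which is how PyMC sizes a variable declared with `dims=...`.
    """
    if isinstance(dims, str):
        dims = (dims,)
    return tuple(len(coords[d]) for d in dims)


def check_param_shape(name, dims, coords, expected_shape):
    """Hypothetical helper: raise TypeError if the prior's shape differs from
    the shape the statespace graph expects for that parameter."""
    actual = shape_from_dims(dims, coords)
    if actual != tuple(expected_shape):
        raise TypeError(
            f"{name}: prior has shape {actual}, "
            f"but the graph expects shape {tuple(expected_shape)}"
        )
    return actual


# Illustrative coords only (assumed name and labels): a two-element cycle state.
coords = {"annual_cycle_state": ["state_0", "state_1"]}
print(shape_from_dims(("annual_cycle_state",), coords))  # (2,)
```

Calling `check_param_shape("annual_cycle", ("annual_cycle_state",), coords, (1,))` raises a `TypeError`, which mirrors the situation in the report: the prior has shape (2,) while the graph's placeholder was declared with shape (1,).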

I guess the problem is that annual_cycle should have shape (2,) but has shape (1,). Could this point to a bug in the library code? For instance, at line 1329, init_state is given shape=(1,), and I wonder whether it should be shape=(2,). However, changing this alone did not fix the problem in my tests, so further changes may be needed.
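For context on why two elements are expected: assuming `CycleComponent` follows the standard trigonometric cycle used in structural time-series models, each cycle carries a two-dimensional state (a cosine and a sine term) that is rotated by the angle 2π/cycle_length at every step. The NumPy sketch below builds that 2×2 transition block; `cycle_transition` is a hypothetical name and this is not the library's own code. A length-1 initial state cannot be propagated through a 2×2 transition, which is consistent with the shape mismatch in the error.

```python
import numpy as np


def cycle_transition(cycle_length):
    """Hypothetical helper: 2x2 rotation block of a trigonometric cycle.

    The state [c_t, c*_t] is rotated by lam = 2*pi / cycle_length per step.
    """
    lam = 2 * np.pi / cycle_length
    return np.array([[np.cos(lam), np.sin(lam)],
                     [-np.sin(lam), np.cos(lam)]])


T = cycle_transition(12)
state = np.array([1.0, 0.0])  # two-element initial state, shape (2,)
print(T.shape)  # (2, 2)

# After one full period (12 steps for monthly data) the state returns to its start.
print(np.allclose(np.linalg.matrix_power(T, 12) @ state, state))  # True

# A one-element initial state is incompatible with the 2x2 transition block.
try:
    T @ np.array([1.0])
except ValueError as err:
    print("shape mismatch:", err)
```

Because the rotation is orthogonal, the cycle neither grows nor decays without innovations; with `innovations=True`, noise is added to both state elements at each step.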