pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io
Other
72 stars 46 forks source link

including cyclic or seasonal components causes error messages from build_statespace_graph since last bug fix #289

Closed rklees closed 6 months ago

rklees commented 6 months ago

I still do not get pymc_experimental running when adding a cycle and/or a seasonal to the model. Below an example code. In this case, the model consists of an integrated random walk + cycle + seasonal with 1 harmonic + measurement noise. I can generate the data associated with this model, but when trying to integrate into pymc I get an error message from build_statespace_graph. The error message points to inconsistencies in the shape of the cyclic and/or seasonal component (shape (2,) versus shape (1,) ) in base code. The error appears when I add any combination of cyclic and/or seasonal components. Before the last error fix related to cyclic components documented on Github it only appeared when a cyclic component was added.

Here is the code and the error message:

import jax

jax.config.update("jax_platform_name", "cpu") import numpyro

import blackjax

import nutpie

numpyro.set_host_device_count(4)

import sys

sys.path.append("..") print(sys.path)

import pymc_experimental.statespace

import importlib importlib.reload(pymc_experimental.statespace.structural)

from pymc_experimental.statespace import structural as st from pymc_experimental.statespace.utils.constants import SHORT_NAME_TO_LONG, MATRIX_NAMES import matplotlib.pyplot as plt import pymc as pm import arviz as az import pytensor import pytensor.tensor as pt import numpy as np import pandas as pd from patsy import dmatrix from pymc_experimental.statespace.core.representation import PytensorRepresentation import xarray as xr

%reload_ext autoreload %autoreload complete

from importlib.metadata import version print('pymc version = ', version('pymc')) print('pytensor version = ', version('pytensor')) # print('pandas version = ', version('pandas')) print('pandas version = ', pd.version) print('arviz version = ', version('arviz')) print('numpy version = ', version('numpy')) print('pytensor version = ', version('pytensor'))

print('blackjax version = ', version('blackjax'))

print('nutpie version = ', version('nutpie')) print('xarray version = ', version('xarray'))

plt.rcParams.update( { "figure.figsize": (14, 4), "figure.dpi": 144, "figure.constrained_layout.use": True, "axes.grid": True, "grid.linewidth": 0.5, "grid.linestyle": "--", "axes.spines.top": False, "axes.spines.bottom": False, "axes.spines.left": False, "axes.spines.right": False, } )

def unpack_statespace(ssm): return [ssm[SHORT_NAME_TO_LONG[x]] for x in MATRIX_NAMES]

def unpack_symbolic_matrices_with_params(mod, param_dict): f_matrices = pytensor.function( list(mod._name_to_variable.values()), unpack_statespace(mod.ssm), on_unused_input="ignore" ) x0, P0, c, d, T, Z, R, H, Q = f_matrices(**param_dict) return x0, P0, c, d, T, Z, R, H, Q

def simulate_from_numpy_model(mod, rng, param_dict, steps=100): """ Helper function to visualize the components outside of a PyMC model context """ x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, param_dict) Z_time_varies = Z.ndim == 3

k_states = mod.k_states
k_posdef = mod.k_posdef

x = np.zeros((steps, k_states))
y = np.zeros(steps)

x[0] = x0
if Z_time_varies:
    y[0] = Z[0] @ x0
else:
    y[0] = Z @ x0

if not np.allclose(H, 0):
    y[0] += rng.multivariate_normal(mean=np.zeros(1), cov=H)

for t in range(1, steps):
    if k_posdef > 0:
        shock = rng.multivariate_normal(mean=np.zeros(k_posdef), cov=Q)
        innov = R @ shock
    else:
        innov = 0

    if not np.allclose(H, 0):
        error = rng.multivariate_normal(mean=np.zeros(1), cov=H)
    else:
        error = 0

    x[t] = c + T @ x[t - 1] + innov

    if Z_time_varies:
        y[t] = d + Z[t] @ x[t] + error
    else:
        y[t] = d + Z @ x[t] + error

return x, y

def simulate_many_trajectories(mod, rng, param_dict, n_simulations, steps=100): k_states = mod.k_states k_posdef = mod.k_posdef

xs = np.zeros((n_simulations, steps, k_states))
ys = np.zeros((n_simulations, steps))

for i in range(n_simulations):
    x, y = simulate_from_numpy_model(mod, rng, param_dict, steps)
    xs[i] = x
    ys[i] = y
return xs, ys

seed = sum(map(ord, "Structural Timeseries")) rng = np.random.default_rng(seed)

measurement_error = st.MeasurementError(name="obs") IRW = st.LevelTrendComponent(order=2, innovations_order=[0, 1]) cycle = st.CycleComponent(name="annual_cycle", cycle_length=12, innovations=True) # cycle length is the period in number of samples; non-integer periods are allowed SA_cycle = st.FrequencySeasonality( name="SA_cycle", season_length=5.347, n=1, innovations=True # season_length is the period in units of number of sampels; non-integer values allowed )

param_dict = { "initial_trend": np.zeros((2,)), "sigma_trend": np.array([0.2]), "annual_cycle": np.array([10., 0.]), "sigma_annual_cycle": np.array([1.0]),
"SA_cycle": np.array([20., 0.]), "sigma_SA_cycle": np.array([0.5]), "sigma_obs": np.array([0.1]), }

mod = IRW + cycle + SA_cycle + measurement_error

x, y = simulate_from_numpy_model(mod, rng, param_dict, steps=144)

plt.figure(figsize=(10,5)) plt.plot(y), plt.title('IRW plus annual cycle plus SA_cycle')

plt.figure(figsize=(10,5)) plt.plot(x[:, 0]), plt.title('level component')

plt.figure(figsize=(10,5)) plt.plot(x[:, 1]), plt.title('trend component')

plt.figure(figsize=(10,5)) plt.plot(x[:, 2]), plt.title('annual cycle component')

plt.figure(figsize=(10,5)) plt.plot(x[:, 4]), plt.title('SA cycle component')

time = np.arange(144)/12 data = pd.DataFrame({'time': time, 'meas': y}) nobs = len(data['meas']) dt = np.mean(np.diff(data['time'])) # sampling period in units of years

mod = st.LevelTrendComponent(order=2, innovations_order=[0, 1]) mod += st.CycleComponent(name='annual_cycle', cycle_length=12, innovations=True) mod += st.FrequencySeasonality(name='SA_cycle', season_length=5.347, n=1, innovations=True) mod += st.MeasurementError(name="obs")

model = mod.build(name="IRW+cycle+measurement_error")

with pm.Model(coords=coords) as model_1: P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0]) P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims) initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims) sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=10, dims=sigma_trend_dims) annual_cycle = pm.Normal("annual_cycle", sigma=5) sigma_annual_cycle = pm.Gamma("sigma_annual_cycle", alpha=2, beta=5) SA_cycle = pm.Normal("SA_cycle", sigma=5) sigma_SA_cycle = pm.Gamma("sigma_SA_cycle", alpha=2, beta=5) sigma_obs = pm.Gamma("sigma_obs", alpha=2, beta=5, dims=('observed_state',))

with model_1:
model.build_statespace_graph(data['meas'], mode="JAX")


TypeError Traceback (most recent call last) Cell In[12], line 2 1 with model_1:
----> 2 model.build_statespace_graph(data['meas'], mode="JAX")

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pymc_experimental/statespace/core/statespace.py:779, in PyMCStateSpace.build_statespace_graph(self, data, register_data, mode, missing_fill_value, cov_jitter, return_updates, include_smoother) 721 """ 722 Given a parameter vector theta, constructs the full computational graph describing the state space model and 723 the associated log probability of the data. Hidden states and log probabilities are computed via the Kalman (...) 775 If return_updates is False, the method will return None. 776 """ 777 pm_mod = modelcontext(None) --> 779 self._insert_random_variables() 780 obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None) 782 self.data_len = data.shape[0]

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pymc_experimental/statespace/core/statespace.py:637, in PyMCStateSpace._insert_random_variables(self) 633 matrices = list(self._unpack_statespace_with_placeholders()) 634 replacement_dict = { 635 var: pt.atleast_1d(pymc_model[name]) for name, var in self._name_to_variable.items() 636 } --> 637 self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True)

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/replace.py:205, in graph_replace(outputs, replace, strict) 197 raise ValueError(f"{key} is not a part of graph") 199 sorted_replacements = sorted( 200 fg_replace.items(), 201 # sort based on the fg toposort, if a variable has no owner, it goes first 202 key=partial(toposort_key, fg, toposort), 203 reverse=True, 204 ) --> 205 fg.replace_all(sorted_replacements, import_missing=True) 206 if as_list: 207 return list(fg.outputs)

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:515, in FunctionGraph.replace_all(self, pairs, kwargs) 513 """Replace variables in the FunctionGraph according to (var, new_var) pairs in a list.""" 514 for var, new_var in pairs: --> 515 self.replace(var, new_var, kwargs)

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:508, in FunctionGraph.replace(self, var, new_var, reason, verbose, import_missing) 501 raise AssertionError( 502 "The replacement variable has a test value with " 503 "a shape different from the original variable's " 504 f"test value. Original: {tval_shape}, new: {new_tval_shape}" 505 ) 507 for node, i in list(self.clients[var]): --> 508 self.change_node_input( 509 node, i, new_var, reason=reason, import_missing=import_missing 510 )

File /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytensor/graph/fg.py:428, in FunctionGraph.change_node_input(self, node, i, new_var, reason, import_missing, check) 426 r = node.inputs[i] 427 if check and not r.type.is_super(new_var.type): --> 428 raise TypeError( 429 f"The type of the replacement ({new_var.type}) must be " 430 f"compatible with the type of the original Variable ({r.type})." 431 ) 432 node.inputs[i] = new_var 434 if r is new_var:

TypeError: The type of the replacement (Vector(float64, shape=(1,))) must be compatible with the type of the original Variable (Vector(float64, shape=(2,))).