Closed A108669 closed 2 months ago
CC @jessegrabowski
Hey, thanks for giving the module a try!
I can reproduce the problem on my end, so it's definitely a bug. It looks like a JAX problem; I can run your code if I use the default PyMC sampler. Also, if you have more than one regressor it works. For example, this runs:
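import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
# Assumed import path for the structural module (aliased as `st` below)
from pymc_experimental.statespace import structural as st
# Generate dummy monthly data with a linear trend and exogenous regressors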
n_samples = 1000
k_exog = 3
np.random.seed(100)
trend_data = np.arange(n_samples) * .1
true_betas = np.random.normal(size=(k_exog,))
regressor_data = np.random.normal(scale=2, size=(n_samples, k_exog))
y = trend_data + regressor_data @ true_betas + np.random.normal(scale=2, size=n_samples) + 10
# Put y first so the data lines up with the column labels below
df = pd.DataFrame(np.c_[y, regressor_data],
                  index=pd.date_range("2001-01-01", freq="M", periods=n_samples),
                  columns=['y'] + [f'x_{i}' for i in range(k_exog)])
df.index.freq = 'M'
trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=k_exog, state_names=['x0', 'x1', 'x2'])
error = st.MeasurementError(name="error")
mod = trend + error + regressor
ss_mod = mod.build(name="test")
trend_dims, obs_dims, regressor_dims, regression_data_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords
with pm.Model(coords=coords) as model_1:
    data_xreg = pm.MutableData("data_xreg", df.drop(columns='y').values)
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)
    initial_trend = pm.Normal("initial_trend", dims=trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5, dims=["observed_state"])
    beta_xreg = pm.Normal("beta_xreg", .2, 1, dims=regressor_dims)
    ss_mod.build_statespace_graph(df[['y']], mode='JAX')
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.9)
    # prior = pm.sample_prior_predictive(samples=10)
Probably the data is being incorrectly squeezed somewhere. I'll look closely and push a fix ASAP. Thanks for finding this bug and opening an issue!
I finally had some time to look closely at this. It appears to be a bug that arises because broadcastable dimensions are signified by a shape of 1 in pytensor. This means the program considers them dynamic, since they might change after broadcasting. As a result, JAX gets upset by the graph, because it doesn't allow dynamic shapes. This is why the model works if you have more than one exogenous variable -- the 2nd dimension of the exogenous data isn't 1 anymore, and everything is inferred to be static. Might be related to https://github.com/pymc-devs/pytensor/issues/408, but not sure.
For now, I can think of two possible work-arounds:
1. Specify the shape of the data in pm.MutableData, by passing a shape keyword argument.
2. Use pm.ConstantData instead of pm.MutableData.
Despite my choice of ordering, I think option 2 is preferable.
Here is a working example:
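import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
# Assumed import path for the structural module (aliased as `st` below)
from pymc_experimental.statespace import structural as st
# Generate dummy monthly data with a linear trend and one exogenous regressor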
n_samples = 1000
k_exog = 1
np.random.seed(100)
trend_data = np.arange(n_samples) * .1
true_betas = np.random.normal(size=(k_exog,))
regressor_data = np.random.normal(scale=2, size=(n_samples, k_exog))
y = trend_data + regressor_data @ true_betas + np.random.normal(scale=2, size=n_samples) + 10
df = pd.DataFrame(np.c_[y, regressor_data],
                  index=pd.date_range("2001-01-01", freq="ME", periods=n_samples),
                  columns=['y'] + [f'x_{i}' for i in range(k_exog)])
df.index.freq = 'ME'
trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=k_exog, state_names=[f'x{i}' for i in range(k_exog)])
error = st.MeasurementError(name="error")
mod = trend + error + regressor
ss_mod = mod.build(name="test")
trend_dims, obs_dims, regressor_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords
with pm.Model(coords=coords) as model_1:
    # Option 1:
    data_xreg = pm.MutableData("data_xreg", df.drop(columns='y').values,
                               dims=['time', 'exog_state'],
                               shape=(n_samples, k_exog))  # <--- Key line
    # Option 2:
    # data_xreg = pm.ConstantData("data_xreg", df.drop(columns='y').values,
    #                             dims=['time', 'exog_state'])
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)
    initial_trend = pm.Normal("initial_trend", dims=trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5)
    beta_xreg = pm.Normal("beta_xreg", .2, 1, dims=regressor_dims)
    ss_mod.build_statespace_graph(df[['y']], mode='JAX')
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.9)
Note that I also specified n_samples. If you don't, JAX will bark at you about dynamic shapes when you try to do post-estimation sampling (ss_mod.sample_conditional_posterior, for example).
I'll let you know when I come up with a more long-term solution.
You only need to specify the shape for broadcastable dims if you intend it to broadcast. You can pass shape=(None, 1) if you still want the other dim to be resizeable (but cannot broadcast it with other parameters).
Yes, this works as well (with pm.MutableData), but JAX will still error on pm.sample_posterior_predictive, complaining about dynamic slicing. So I recommend just declaring both for now, since conditional forecasting with exogenous timeseries isn't supported yet anyway.
This should be fixed by #326. Feel free to open a new issue if you're still hitting problems. I updated the structural example notebook, but it still needs more work. Still, it should give you an idea of how to include exogenous regressors.
Hello! I have been experimenting with the Structural Time Series module; however, I have been running into trouble when I attempt to add an external regressor to a model. Below is an example including dummy data as well as the full stack trace. I have not been able to get any external regressors added to the models.
Full Stack Trace: