pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

`statespace`: Leveraging RegressionComponent yields error #297

Closed: A108669 closed this 2 months ago

A108669 commented 6 months ago

Hello! I have been experimenting with the Structural Time Series module; however, I have been running into trouble when I attempt to add an external regressor to a model. Below is an example including dummy data as well as the full stack trace. I have not been able to get any external regressors added to the models.


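import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.statespace import structural as st

# Generate dummy data: monthly frequency, linear trend, and one regressor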
n_samples = 1000
np.random.seed(100)

trend_data = np.arange(n_samples) * .1
regressor_data = np.random.normal(scale=2, size=n_samples)
y = trend_data + regressor_data + np.random.normal(scale=2, size=n_samples) + 10
df = pd.DataFrame(
    data={
        'time_index': pd.date_range("2001-01-01", freq="M", periods=len(y)),
        'x': regressor_data,
        'y': y,
    }
)

trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=1)
error = st.MeasurementError(name="error")

df = df.set_index("time_index")
df.index.freq = 'M'

mod = trend + error + regressor
ss_mod = mod.build(name="test")
# param_dims lists initial_trend, sigma_error, beta_xreg, data_xreg, P0 (in that order)
initial_trend_dims, sigma_error_dims, regressor_dims, data_xreg_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords) as model_1:
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)

    initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5, dims=["observed_state"])

    beta_xreg = pm.Normal("beta_xreg", .2, 1)
    data_xreg = pm.MutableData("data_xreg", df[["x"]])

    ss_mod.build_statespace_graph(df[['y']], mode="JAX")
    idata = pm.sample(nuts_sampler="numpyro", target_accept=0.9)
ValueError: Argument [[-0.85651 ... 33556989]] given to the scan node is not compatible with its corresponding loop function variable *4-<Matrix(float64, shape=(?, ?))>

Full Stack Trace:

The following parameters should be assigned priors inside a PyMC model block: 
        initial_trend -- shape: (2,), constraints: None, dims: ('trend_state',)
        sigma_error -- shape: (1,), constraints: Positive, dims: None
        beta_xreg -- shape: (1,), constraints: None, dims: ('exog_state',)
        data_xreg -- shape: (None, 1), constraints: None, dims: ('time', 'exog_state')
        P0 -- shape: (3, 3), constraints: Positive semi-definite, dims: ('state', 'state_aux')
Compiling...
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[69], line 17
     14 data_xreg = pm.MutableData("data_xreg", df[["x"]])
     16 ss_mod.build_statespace_graph(df[['y']], mode="JAX")
---> 17 idata = pm.sample(nuts_sampler="numpyro", target_accept=0.9)

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/mcmc.py:696, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    692     if not isinstance(step, NUTS):
    693         raise ValueError(
    694             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    695         )
--> 696     return _sample_external_nuts(
    697         sampler=nuts_sampler,
    698         draws=draws,
    699         tune=tune,
    700         chains=chains,
    701         target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    702         random_seed=random_seed,
    703         initvals=initvals,
    704         model=model,
    705         progressbar=progressbar,
    706         idata_kwargs=idata_kwargs,
    707         nuts_sampler_kwargs=nuts_sampler_kwargs,
    708         **kwargs,
    709     )
    711 if isinstance(step, list):
    712     step = CompoundStep(step)

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/mcmc.py:350, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    347 elif sampler == "numpyro":
    348     import pymc.sampling.jax as pymc_jax
--> 350     idata = pymc_jax.sample_numpyro_nuts(
    351         draws=draws,
    352         tune=tune,
    353         chains=chains,
    354         target_accept=target_accept,
    355         random_seed=random_seed,
    356         initvals=initvals,
    357         model=model,
    358         progressbar=progressbar,
    359         idata_kwargs=idata_kwargs,
    360         **nuts_sampler_kwargs,
    361     )
    362     return idata
    364 elif sampler == "blackjax":

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/jax.py:669, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, idata_kwargs, nuts_kwargs, postprocessing_chunks)
    660 logger.info("Compiling...")
    662 init_params = _get_batched_jittered_initial_points(
    663     model=model,
    664     chains=chains,
    665     initvals=initvals,
    666     random_seed=random_seed,
    667 )
--> 669 logp_fn = get_jaxified_logp(model, negative_logp=False)
    671 nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
    672 nuts_kernel = NUTS(
    673     potential_fn=logp_fn,
    674     target_accept_prob=target_accept,
    675     **nuts_kwargs,
    676 )

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/jax.py:151, in get_jaxified_logp(model, negative_logp)
    149 if not negative_logp:
    150     model_logp = -model_logp
--> 151 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
    153 def logp_fn_wrap(x):
    154     return logp_fn(*x)[0]

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/jax.py:126, in get_jaxified_graph(inputs, outputs)
    120 def get_jaxified_graph(
    121     inputs: Optional[List[TensorVariable]] = None,
    122     outputs: Optional[List[TensorVariable]] = None,
    123 ) -> List[TensorVariable]:
    124     """Compile an PyTensor graph into an optimized JAX function"""
--> 126     graph = _replace_shared_variables(outputs) if outputs is not None else None
    128     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
    129     # We need to add a Supervisor to the fgraph to be able to run the
    130     # JAX sequential optimizer without warnings. We made sure there
    131     # are no mutable input variables, so we only need to check for
    132     # "destroyers". This should be automatically handled by PyTensor
    133     # once https://github.com/aesara-devs/aesara/issues/637 is fixed.

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pymc/sampling/jax.py:116, in _replace_shared_variables(graph)
    109     raise ValueError(
    110         "Graph contains shared variables with default_update which cannot "
    111         "be safely replaced."
    112     )
    114 replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables}
--> 116 new_graph = clone_replace(graph, replace=replacements)
    117 return new_graph

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/graph/replace.py:87, in clone_replace(output, replace, **rebuild_kwds)
     84 _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
     86 # TODO Explain why we call it twice ?!
---> 87 _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
     89 return outs

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:317, in rebuild_collect_shared(outputs, inputs, replace, updates, rebuild_strict, copy_inputs_over, no_default_updates, clone_inner_graphs)
    315 for v in outputs:
    316     if isinstance(v, Variable):
--> 317         cloned_v = clone_v_get_shared_updates(v, copy_inputs_over)
    318         cloned_outputs.append(cloned_v)
    319     elif isinstance(v, Out):

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:193, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    191 if owner not in clone_d:
    192     for i in owner.inputs:
--> 193         clone_v_get_shared_updates(i, copy_inputs_over)
    194     clone_node_and_cache(
    195         owner,
    196         clone_d,
    197         strict=rebuild_strict,
    198         clone_inner_graphs=clone_inner_graphs,
    199     )
    200 return clone_d.setdefault(v, v)

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:193, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    191 if owner not in clone_d:
    192     for i in owner.inputs:
--> 193         clone_v_get_shared_updates(i, copy_inputs_over)
    194     clone_node_and_cache(
    195         owner,
    196         clone_d,
    197         strict=rebuild_strict,
    198         clone_inner_graphs=clone_inner_graphs,
    199     )
    200 return clone_d.setdefault(v, v)

    [... skipping similar frames: rebuild_collect_shared.<locals>.clone_v_get_shared_updates at line 193 (3 times)]

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:193, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    191 if owner not in clone_d:
    192     for i in owner.inputs:
--> 193         clone_v_get_shared_updates(i, copy_inputs_over)
    194     clone_node_and_cache(
    195         owner,
    196         clone_d,
    197         strict=rebuild_strict,
    198         clone_inner_graphs=clone_inner_graphs,
    199     )
    200 return clone_d.setdefault(v, v)

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:194, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
    192         for i in owner.inputs:
    193             clone_v_get_shared_updates(i, copy_inputs_over)
--> 194         clone_node_and_cache(
    195             owner,
    196             clone_d,
    197             strict=rebuild_strict,
    198             clone_inner_graphs=clone_inner_graphs,
    199         )
    200     return clone_d.setdefault(v, v)
    201 elif isinstance(v, SharedVariable):

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/graph/basic.py:1200, in clone_node_and_cache(node, clone_d, clone_inner_graphs, **kwargs)
   1196 new_op: Optional["Op"] = cast(Optional["Op"], clone_d.get(node.op))
   1198 cloned_inputs: list[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs]
-> 1200 new_node = node.clone_with_new_inputs(
   1201     cloned_inputs,
   1202     # Only clone inner-graph `Op`s when there isn't a cached clone (and
   1203     # when `clone_inner_graphs` is enabled)
   1204     clone_inner_graph=clone_inner_graphs if new_op is None else False,
   1205     **kwargs,
   1206 )
   1208 if new_op:
   1209     # If we didn't clone the inner-graph `Op` above, because
   1210     # there was a cached version, set the cloned `Apply` to use
   1211     # the cached clone `Op`
   1212     new_node.op = new_op

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/graph/basic.py:283, in Apply.clone_with_new_inputs(self, inputs, strict, clone_inner_graph)
    280     if isinstance(new_op, HasInnerGraph) and clone_inner_graph:  # type: ignore
    281         new_op = new_op.clone()  # type: ignore
--> 283     new_node = new_op.make_node(*new_inputs)
    284     new_node.tag = copy(self.tag).__update__(new_node.tag)
    285 else:

File ~/miniconda3/envs/pymc_exp/lib/python3.11/site-packages/pytensor/scan/op.py:1195, in Scan.make_node(self, *inputs)
   1193     new_inputs.append(outer_nonseq)
   1194     if not outer_nonseq.type.in_same_class(inner_nonseq.type):
-> 1195         raise ValueError(
   1196             f"Argument {outer_nonseq} given to the scan node is not"
   1197             f" compatible with its corresponding loop function variable {inner_nonseq}"
   1198         )
   1200 for outer_nitsot in self.outer_nitsot(inputs):
   1201     # For every nit_sot input we get as input a int/uint that
   1202     # depicts the size in memory for that sequence. This feature is
   1203     # used by truncated BPTT and by scan space optimization
   1204     if (
   1205         str(outer_nitsot.type.dtype) not in integer_dtypes
   1206         or outer_nitsot.ndim != 0
   1207     ):

ValueError: Argument [[-0.85651 ... 33556989]] given to the scan node is not compatible with its corresponding loop function variable *4-<Matrix(float64, shape=(?, ?))>

ricardoV94 commented 6 months ago

CC @jessegrabowski

jessegrabowski commented 6 months ago

Hey, thanks for giving the module a try!

I can reproduce the problem on my end, so it's definitely a bug. It looks like a JAX-specific problem: your code runs if I use the default PyMC sampler. Also, it works if you have more than one regressor. For example, this runs:


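import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.statespace import structural as st

# Generate dummy data: monthly frequency, linear trend, and k_exog regressors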
n_samples = 1000
k_exog = 3

np.random.seed(100)

trend_data = np.arange(n_samples) * .1
true_betas = np.random.normal(size=(k_exog,))
regressor_data = np.random.normal(scale=2, size=(n_samples, k_exog))
y = trend_data + regressor_data @ true_betas + np.random.normal(scale=2, size=n_samples) + 10
df = pd.DataFrame(np.c_[y, regressor_data],  # y first, to match the column labels
                  index=pd.date_range("2001-01-01", freq="M", periods=n_samples),
                  columns=['y'] + [f'x_{i}' for i in range(k_exog)])
df.index.freq = 'M'

trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=k_exog, state_names=['x0', 'x1', 'x2'])
error = st.MeasurementError(name="error")

mod = trend + error + regressor
ss_mod = mod.build(name="test")
trend_dims, obs_dims, regressor_dims, regression_data_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords) as model_1:
    data_xreg = pm.MutableData("data_xreg", df.drop(columns='y').values)

    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)

    initial_trend = pm.Normal("initial_trend", dims=trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5, dims=["observed_state"])

    beta_xreg = pm.Normal("beta_xreg", .2, 1, dims=regressor_dims)

    ss_mod.build_statespace_graph(df[['y']], mode='JAX')
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.9)
    # prior = pm.sample_prior_predictive(samples=10)

Probably the data is being incorrectly squeezed somewhere. I'll look closely and push a fix ASAP. Thanks for finding this bug and opening an issue!

jessegrabowski commented 4 months ago

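I finally had some time to look closely at this. It appears to be a bug that arises because, in PyTensor, a dimension is treated as broadcastable only when its static shape is known to be 1. pm.MutableData creates a shared variable whose static shape is unknown, so the statespace graph (including its scan) is built around a non-broadcastable matrix. When sampling through JAX, PyMC replaces shared variables with constants (_replace_shared_variables in the traceback), and a constant of shape (n, 1) has a broadcastable second dimension. The scan then receives an input whose type doesn't match the one it was built with, and raises the error above. This is why the model works if you have more than one exogenous variable: the second dimension of the exogenous data isn't 1 anymore, so the two types agree. Might be related to https://github.com/pymc-devs/pytensor/issues/408, but not sure.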
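To make the mechanism concrete, here is a toy model in plain Python. The helper names are hypothetical and are not PyTensor API; they only mimic the rule discussed above: a dimension counts as broadcastable when its static length is exactly 1, an unknown (None) length is not broadcastable, and two types are treated as compatible only when their broadcastable patterns agree. It assumes the scan was built with a fully unknown (None, None) shape, which is what the Matrix(float64, shape=(?, ?)) in the error message suggests.

```python
# Toy model (hypothetical helpers, not PyTensor's real classes) of the
# compatibility check that fails inside Scan.make_node.

def broadcastable_pattern(static_shape):
    """One flag per dimension: True only if the static length is exactly 1."""
    return tuple(dim == 1 for dim in static_shape)


def same_class(outer_shape, inner_shape):
    """Compatible only if ndim and broadcastable pattern both match."""
    return (
        len(outer_shape) == len(inner_shape)
        and broadcastable_pattern(outer_shape) == broadcastable_pattern(inner_shape)
    )


def check_after_constant_swap(n_samples, k_exog):
    # pm.MutableData: static shape is unknown when the scan is built
    inner = (None, None)
    # JAX path: the shared variable is swapped for a constant whose
    # static shape is the actual data shape
    outer = (n_samples, k_exog)
    return same_class(outer, inner)


for k in (1, 3):
    print(k, check_after_constant_swap(1000, k))
# 1 False
# 3 True
```

Only the single-regressor case flips the pattern from (False, False) to (False, True), which matches the observation that three regressors sample fine.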

For now, I can think of two possible work-arounds:

  1. Explicitly specify the shape of the exogenous data when you create the pm.MutableData, by passing a shape keyword argument.
  2. Use pm.ConstantData instead of pm.MutableData.

Despite my choice of ordering, I think option 2 is preferable.
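Why both work-arounds help can be sketched with the same kind of toy rule. This is a hypothetical model, not PyMC or PyTensor code; it assumes that pm.ConstantData yields a constant carrying the data's real shape from the start, and that pm.MutableData(..., shape=...) fixes the static shape as in the working example in this comment. In both cases the build-time type already agrees with the constant that the JAX path substitutes.

```python
# Toy comparison (hypothetical helpers, not PyMC/PyTensor API) of the static
# shape the scan is built with versus the constant used when compiling for JAX.

def broadcastable_pattern(static_shape):
    """One flag per dimension: True only if the static length is exactly 1."""
    return tuple(dim == 1 for dim in static_shape)


def build_time_shape(container, data_shape, declared_shape=None):
    """Static shape seen when the statespace graph is built (toy rule)."""
    if container == "ConstantData":
        # A constant carries the real data shape from the start
        return tuple(data_shape)
    if declared_shape is not None:
        # MutableData(..., shape=...) keeps the declared static shape
        return tuple(declared_shape)
    # Plain MutableData: every dimension is unknown
    return (None,) * len(data_shape)


def survives_jax_swap(container, data_shape, declared_shape=None):
    """True if the build-time type matches the constant used under JAX."""
    built = build_time_shape(container, data_shape, declared_shape)
    return broadcastable_pattern(built) == broadcastable_pattern(data_shape)


data_shape = (1000, 1)  # n_samples x k_exog, with k_exog = 1
cases = [
    ("MutableData", ("MutableData", data_shape)),
    ("MutableData, shape=(1000, 1)", ("MutableData", data_shape, (1000, 1))),
    ("ConstantData", ("ConstantData", data_shape)),
]
for label, args in cases:
    print(f"{label}: {survives_jax_swap(*args)}")
# MutableData: False
# MutableData, shape=(1000, 1): True
# ConstantData: True
```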

Here is a working example:

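import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.statespace import structural as st

# Generate dummy data: monthly frequency, linear trend, and k_exog regressors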
n_samples = 1000
k_exog = 1

np.random.seed(100)

trend_data = np.arange(n_samples) * .1
true_betas = np.random.normal(size=(k_exog,))
regressor_data = np.random.normal(scale=2, size=(n_samples, k_exog))
y = trend_data + regressor_data @ true_betas + np.random.normal(scale=2, size=n_samples) + 10
df = pd.DataFrame(np.c_[y, regressor_data],
                  index = pd.date_range("2001-01-01", freq="ME", periods=n_samples),
                  columns=['y'] + [f'x_{i}' for i in range(k_exog)])
df.index.freq = 'ME'

trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=k_exog, state_names=[f'x{i}' for i in range(k_exog)])
error = st.MeasurementError(name="error")

mod = trend + error + regressor
ss_mod = mod.build(name="test")
trend_dims, obs_dims, regressor_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords) as model_1:

    # Option 1:
    data_xreg = pm.MutableData("data_xreg", df.drop(columns='y').values, 
                               dims=['time', 'exog_state'],
                               shape=(n_samples, k_exog)) # <--- Key line

    # Option 2:
    # data_xreg = pm.ConstantData("data_xreg", df.drop(columns='y').values, 
    #                           dims=['time', 'exog_state'])

    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)

    initial_trend = pm.Normal("initial_trend", dims=trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5)

    beta_xreg = pm.Normal("beta_xreg", .2, 1, dims=regressor_dims)

    ss_mod.build_statespace_graph(df[['y']], mode='JAX')
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.9)

Note that I also specified n_samples in the shape. If you don't, JAX will complain about dynamic shapes when you try to do post-estimation sampling (ss_mod.sample_conditional_posterior, for example).

I'll let you know when I come up with a more long-term solution.

ricardoV94 commented 4 months ago

You only need to specify a static shape for the dimensions you intend to be broadcastable. You can pass shape=(None, 1) if you still want the other (time) dimension to be resizeable; it then cannot broadcast with other parameters.
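The trade-off of shape=(None, 1) can be illustrated with a toy rule. The helpers are hypothetical, not PyTensor API; they assume that None in a declared static shape means "length not fixed" while an integer fixes that length, so new data must match every fixed dimension. Declaring (None, 1) marks the exogenous column as broadcastable while leaving the time length free.

```python
# Toy rules (hypothetical helpers, not PyTensor API) for a data container
# declared with a static shape, where None means "length not fixed".

def broadcastable_pattern(static_shape):
    """True for each dimension whose static length is exactly 1."""
    return tuple(dim == 1 for dim in static_shape)


def accepts(static_shape, new_shape):
    """True if new data of new_shape fits the declared static_shape."""
    return len(static_shape) == len(new_shape) and all(
        fixed is None or fixed == actual
        for fixed, actual in zip(static_shape, new_shape)
    )


declared = (None, 1)  # time length free, exactly one exogenous column

print(broadcastable_pattern(declared))  # (False, True)
print(accepts(declared, (1200, 1)))     # True: more time steps is fine
print(accepts(declared, (1200, 2)))     # False: column count is fixed at 1
print(accepts((1000, 1), (1200, 1)))    # False: here the time length is fixed too
```

Declaring (n_samples, 1) instead fixes both lengths, which is the stricter option recommended in the thread for JAX post-estimation sampling.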

jessegrabowski commented 4 months ago

Yes, this works as well (with pm.MutableData), but JAX will still error on pm.sample_posterior_predictive, complaining about dynamic slicing. So for now I recommend just declaring both dimensions, since conditional forecasting with exogenous time series isn't supported yet anyway.

jessegrabowski commented 2 months ago

This should be fixed by #326. Feel free to open a new issue if you're still hitting problems. I updated the structural example notebook; it still needs more work, but it should give you an idea of how to include exogenous regressors.