pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/
Other
8.63k stars 1.99k forks source link

BUG: Can not use JAX samplers with `GaussianRandomWalk` in pymc 5 (but was ok in pymc 4) #6467

Closed juanitorduz closed 1 year ago

juanitorduz commented 1 year ago

Describe the issue:

I am not able to use `pm.sampling_jax.sample_numpyro_nuts` on a model containing a `GaussianRandomWalk` in pymc >= 5.0.0 (this worked in pymc 4). As a result, I can no longer replicate the model presented in https://juanitorduz.github.io/bikes_pymc/.

Reproducible code example:

# Reproduction script for pymc-devs/pymc#6467: a bike-rental regression with a
# GaussianRandomWalk coefficient, sampled with the JAX/NumPyro NUTS sampler.
# Running the final `with model:` block raises the NotImplementedError shown in
# the traceback of this issue ("No JAX conversion for the given `Op`: Split{2}").
# Requires network access (data_path is a remote CSV) and every package imported
# below.
# NOTE(review): numpy (np) is not used anywhere in this script.
import numpy as np
import pandas as pd
import pymc as pm
# Legacy import path kept as reported; the traceback shows the sampler itself
# is implemented in pymc/sampling/jax.py.
import pymc.sampling_jax
from pymc.distributions.continuous import Exponential
from sklearn.preprocessing import StandardScaler
# NOTE(review): seaborn (sns) is not used anywhere in this script.
import seaborn as sns

# Daily bike-rental dataset fetched over HTTP.
data_path = "https://raw.githubusercontent.com/christophM/interpretable-ml-book/master/data/bike.csv"

raw_data_df = pd.read_csv(data_path)

# Attach a daily `date` column covering 2011-01-01 .. 2012-12-31.
# NOTE(review): assumes raw_data_df has exactly one row per day in this range,
# otherwise assign() cannot align the column — TODO confirm against the CSV.
data_df = raw_data_df.copy().assign(
    date=pd.date_range(
        start=pd.to_datetime("2011-01-01"), end=pd.to_datetime("2012-12-31"), freq="D"
    )
)

# Column holding the raw rental count, and the name of its standardized copy.
target = "cnt"
target_scaled = f"{target}_scaled"

# Separate scalers for the target (endogenous) and the regressors (exogenous),
# so each can later be inverted independently.
endog_scaler = StandardScaler()
exog_scaler = StandardScaler()

data_df[target_scaled] = endog_scaler.fit_transform(X=data_df[[target]])
data_df[["temp_scaled", "hum_scaled", "windspeed_scaled"]] = exog_scaler.fit_transform(
    X=data_df[["temp", "hum", "windspeed"]]
)

n = data_df.shape[0]  # NOTE(review): n is not used below.
# target
cnt = data_df[target].to_numpy()  # NOTE(review): unscaled target; not used below.
cnt_scaled = data_df[target_scaled].to_numpy()
# date feature
date = data_df["date"].to_numpy()
# model regressors
temp_scaled = data_df["temp_scaled"].to_numpy()
hum_scaled = data_df["hum_scaled"].to_numpy()
windspeed_scaled = data_df["windspeed_scaled"].to_numpy()
# Categorical regressors: integer codes (*_idx) used to index per-category
# coefficients, plus the sorted labels used as model coords.
# NOTE(review): holiday_idx / holiday are computed but never used in the model.
holiday_idx, holiday = data_df["holiday"].factorize(sort=True)
workingday_idx, workingday = data_df["workingday"].factorize(sort=True)
weathersit_idx, weathersit = data_df["weathersit"].factorize(sort=True)
# Trend regressor: days_since_2011 rescaled so that its maximum is 1.
t = data_df["days_since_2011"].to_numpy() / data_df["days_since_2011"].max()

# Named dimensions referenced by the `dims=` arguments below.
coords = {
    "date": date,
    "workingday": workingday,
    "weathersit": weathersit,
}

with pm.Model(coords=coords) as model:

    # --- priors ---
    intercept = pm.Normal(name="intercept", mu=0, sigma=2)
    b_hum = pm.Normal(name="b_hum", mu=0, sigma=2)
    b_windspeed = pm.Normal(name="b_windspeed", mu=0, sigma=2)
    b_workingday = pm.Normal(name="b_workingday", mu=0, sigma=2, dims="workingday")
    b_weathersit = pm.Normal(name="b_weathersit", mu=0, sigma=2, dims="weathersit")
    b_t = pm.Normal(name="b_t", mu=0, sigma=3)
    # NOTE(review): Python name `sigma_slopes` vs model name "sigma_slope".
    sigma_slopes = pm.HalfNormal(name="sigma_slope", sigma=0.5)
    nu = pm.Gamma(name="nu", alpha=8, beta=2)
    sigma = pm.HalfNormal(name="sigma", sigma=2)

    # --- model parametrization ---
    # Time-varying coefficient on temp_scaled, one value per date, with its
    # initial value drawn from Exponential(lam=1.0). The issue title ties the
    # JAX failure to this GaussianRandomWalk; the Split{2} Op in the traceback
    # presumably comes from its logp graph — TODO confirm.
    slopes = pm.GaussianRandomWalk(
        name="slopes",
        sigma=sigma_slopes,
        init_dist=Exponential.dist(lam=1.0),
        dims="date",
    )
    # Linear predictor: intercept + trend + time-varying temperature effect
    # + humidity/windspeed effects + per-category offsets.
    mu = pm.Deterministic(
        name="mu",
        var=(
            intercept
            + b_t * t
            + slopes * temp_scaled
            + b_hum * hum_scaled
            + b_windspeed * windspeed_scaled
            + b_workingday[workingday_idx]
            + b_weathersit[weathersit_idx]
        ),
        dims="date",
    )

    # --- likelihood ---
    # Student-t likelihood on the standardized rental count.
    likelihood = pm.StudentT(
        name="likelihood", mu=mu, nu=nu, sigma=sigma, dims="date", observed=cnt_scaled
    )

with model:
    # Fails here: get_jaxified_logp -> jax_funcify raises NotImplementedError
    # for Split{2}, so sample_posterior_predictive below is never reached.
    idata = pm.sampling_jax.sample_numpyro_nuts(
        target_accept=0.8, draws=1000, chains=4
    )
    posterior_predictive = pm.sample_posterior_predictive(trace=idata)

Error message:

NotImplementedError                       Traceback (most recent call last)
Cell In [1], line 95
     89     likelihood = pm.StudentT(
     90         name="likelihood", mu=mu, nu=nu, sigma=sigma, dims="date", observed=cnt_scaled
     91     )
     94 with model:
---> 95     idata = pm.sampling_jax.sample_numpyro_nuts(
     96         target_accept=0.8, draws=1000, chains=4
     97     )
     98     posterior_predictive = pm.sample_posterior_predictive(trace=idata)

File ~/opt/anaconda3/envs/pymmmc_env/lib/python3.8/site-packages/pymc/sampling/jax.py:614, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_chunks, idata_kwargs, nuts_kwargs)
    605 print("Compiling...", file=sys.stdout)
    607 init_params = _get_batched_jittered_initial_points(
    608     model=model,
    609     chains=chains,
    610     initvals=initvals,
    611     random_seed=random_seed,
    612 )
--> 614 logp_fn = get_jaxified_logp(model, negative_logp=False)
    616 nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
    617 nuts_kernel = NUTS(
    618     potential_fn=logp_fn,
    619     target_accept_prob=target_accept,
    620     **nuts_kwargs,
    621 )

File ~/opt/anaconda3/envs/pymmmc_env/lib/python3.8/site-packages/pymc/sampling/jax.py:118, in get_jaxified_logp(model, negative_logp)
    116 if not negative_logp:
    117     model_logp = -model_logp
--> 118 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
    120 def logp_fn_wrap(x):
    121     return logp_fn(*x)[0]

File ~/opt/anaconda3/envs/pymmmc_env/lib/python3.8/site-packages/pymc/sampling/jax.py:111, in get_jaxified_graph(inputs, outputs)
    108 mode.JAX.optimizer.rewrite(fgraph)
    110 # We now jaxify the optimized fgraph
--> 111 return jax_funcify(fgraph)

File ~/opt/anaconda3/envs/pymmmc_env/lib/python3.8/functools.py:875, in singledispatch.<locals>.wrapper(*args, **kw)
    871 if not args:
    872     raise TypeError(f'{funcname} requires at least '
    873                     '1 positional argument')
--> 875 return dispatch(args[0].__class__)(*args, **kw)
...
     37 def jax_funcify(op, node=None, storage_map=None, **kwargs):
     38     """Create a JAX compatible function from an PyTensor `Op`."""
---> 39     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")

NotImplementedError: No JAX conversion for the given `Op`: Split{2}

PyMC version information:

import pymc
import pytensor
import jax

print(pymc.__version__)  # -> 5.0.2
print(pytensor.__version__)  # -> 2.9.1
print(jax.__version__)  # -> 3.24

Context for the issue:

No response

ricardoV94 commented 1 year ago

The error comes from the missing JAX conversion for PyTensor's `Split` Op. There is already a tracking issue for this in PyTensor, so I'll close this one here: https://github.com/pymc-devs/pytensor/issues/145

juanitorduz commented 1 year ago

Thanks @ricardoV94 :)