Closed juanitorduz closed 1 year ago
I am not able to use `pm.sampling_jax.sample_numpyro_nuts` in pymc >= 5.0.0 with a `GaussianRandomWalk`. I can no longer replicate the model presented in https://juanitorduz.github.io/bikes_pymc/.
pm.sampling_jax.sample_numpyro_nuts
GaussianRandomWalk
# Minimal reproducer: a regression on the bike-sharing data whose temperature
# slope follows a GaussianRandomWalk, sampled with the NumPyro NUTS backend.
import numpy as np
import pandas as pd
import pymc as pm
import pymc.sampling_jax
from pymc.distributions.continuous import Exponential
from sklearn.preprocessing import StandardScaler
import seaborn as sns

# --- data loading ---
data_path = "https://raw.githubusercontent.com/christophM/interpretable-ml-book/master/data/bike.csv"
raw_data_df = pd.read_csv(data_path)

# Attach an explicit daily date column spanning 2011-01-01 .. 2012-12-31.
data_df = raw_data_df.copy().assign(
    date=pd.date_range(
        start=pd.to_datetime("2011-01-01"),
        end=pd.to_datetime("2012-12-31"),
        freq="D",
    )
)

# --- scaling ---
target = "cnt"
target_scaled = f"{target}_scaled"

# Separate scalers for the target and for the continuous regressors.
endog_scaler = StandardScaler()
exog_scaler = StandardScaler()

data_df[target_scaled] = endog_scaler.fit_transform(X=data_df[[target]])
data_df[["temp_scaled", "hum_scaled", "windspeed_scaled"]] = exog_scaler.fit_transform(
    X=data_df[["temp", "hum", "windspeed"]]
)

n = data_df.shape[0]

# target
cnt = data_df[target].to_numpy()
cnt_scaled = data_df[target_scaled].to_numpy()

# date feature
date = data_df["date"].to_numpy()

# model regressors
temp_scaled = data_df["temp_scaled"].to_numpy()
hum_scaled = data_df["hum_scaled"].to_numpy()
windspeed_scaled = data_df["windspeed_scaled"].to_numpy()

# Integer codes plus the sorted unique levels for each categorical column.
holiday_idx, holiday = data_df["holiday"].factorize(sort=True)
workingday_idx, workingday = data_df["workingday"].factorize(sort=True)
weathersit_idx, weathersit = data_df["weathersit"].factorize(sort=True)

# Time index divided by its maximum value.
t = data_df["days_since_2011"].to_numpy() / data_df["days_since_2011"].max()

coords = {
    "date": date,
    "workingday": workingday,
    "weathersit": weathersit,
}

with pm.Model(coords=coords) as model:
    # --- priors ---
    intercept = pm.Normal(name="intercept", mu=0, sigma=2)
    b_hum = pm.Normal(name="b_hum", mu=0, sigma=2)
    b_windspeed = pm.Normal(name="b_windspeed", mu=0, sigma=2)
    b_workingday = pm.Normal(name="b_workingday", mu=0, sigma=2, dims="workingday")
    b_weathersit = pm.Normal(name="b_weathersit", mu=0, sigma=2, dims="weathersit")
    b_t = pm.Normal(name="b_t", mu=0, sigma=3)
    sigma_slopes = pm.HalfNormal(name="sigma_slope", sigma=0.5)
    nu = pm.Gamma(name="nu", alpha=8, beta=2)
    sigma = pm.HalfNormal(name="sigma", sigma=2)

    # --- model parametrization ---
    # Time-varying coefficient on temperature, one value per date.
    slopes = pm.GaussianRandomWalk(
        name="slopes",
        sigma=sigma_slopes,
        init_dist=Exponential.dist(lam=1.0),
        dims="date",
    )
    mu = pm.Deterministic(
        name="mu",
        var=(
            intercept
            + b_t * t
            + slopes * temp_scaled
            + b_hum * hum_scaled
            + b_windspeed * windspeed_scaled
            + b_workingday[workingday_idx]
            + b_weathersit[weathersit_idx]
        ),
        dims="date",
    )

    # --- likelihood ---
    likelihood = pm.StudentT(
        name="likelihood",
        mu=mu,
        nu=nu,
        sigma=sigma,
        dims="date",
        observed=cnt_scaled,
    )

with model:
    # This is the call reported to raise NotImplementedError under pymc >= 5.0.0.
    idata = pm.sampling_jax.sample_numpyro_nuts(
        target_accept=0.8, draws=1000, chains=4
    )
    posterior_predictive = pm.sample_posterior_predictive(trace=idata)
NotImplementedError Traceback (most recent call last) Cell In [1], line 95 89 likelihood = pm.StudentT( 90 name="likelihood", mu=mu, nu=nu, sigma=sigma, dims="date", observed=cnt_scaled 91 ) 94 with model: ---> 95 idata = pm.sampling_jax.sample_numpyro_nuts( 96 target_accept=0.8, draws=1000, chains=4 97 ) 98 posterior_predictive = pm.sample_posterior_predictive(trace=idata) File ~/opt/anaconda3/envs/pymmmc_env/lib/python3.8/site-packages/pymc/sampling/jax.py:614, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_chunks, idata_kwargs, nuts_kwargs) 605 print("Compiling...", file=sys.stdout) 607 init_params = _get_batched_jittered_initial_points( 608 model=model, 609 chains=chains, 610 initvals=initvals, 611 random_seed=random_seed, 612 ) --> 614 logp_fn = get_jaxified_logp(model, negative_logp=False) 616 nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs) 617 nuts_kernel = NUTS( 618 potential_fn=logp_fn, 619 target_accept_prob=target_accept, 620 **nuts_kwargs, 621 ) File ~/opt/anaconda3/envs/pymmmc_env/lib/python3.8/site-packages/pymc/sampling/jax.py:118, in get_jaxified_logp(model, negative_logp) 116 if not negative_logp: 117 model_logp = -model_logp --> 118 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp]) 120 def logp_fn_wrap(x): 121 return logp_fn(*x)[0] File ~/opt/anaconda3/envs/pymmmc_env/lib/python3.8/site-packages/pymc/sampling/jax.py:111, in get_jaxified_graph(inputs, outputs) 108 mode.JAX.optimizer.rewrite(fgraph) 110 # We now jaxify the optimized fgraph --> 111 return jax_funcify(fgraph) File ~/opt/anaconda3/envs/pymmmc_env/lib/python3.8/functools.py:875, in singledispatch.<locals>.wrapper(*args, **kw) 871 if not args: 872 raise TypeError(f'{funcname} requires at least ' 873 '1 positional argument') --> 875 return dispatch(args[0].__class__)(*args, **kw) ... 
37 def jax_funcify(op, node=None, storage_map=None, **kwargs): 38 """Create a JAX compatible function from an PyTensor `Op`.""" ---> 39 raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}") NotImplementedError: No JAX conversion for the given `Op`: Split{2}
import pymc import pytensor import jax print (pymc.__version__) -> 5.0.2 print (pytensor.__version__) -> 2.9.1 print (jax.__version__) -> 3.24
No response
It's related to having no JAX conversion of the Split Op. There is a tracking issue on PyTensor already, so I'll close this one here: https://github.com/pymc-devs/pytensor/issues/145
Thanks @ricardoV94 :)
Describe the issue:
I am not able to use `pm.sampling_jax.sample_numpyro_nuts` in pymc >= 5.0.0 with a `GaussianRandomWalk`. I can no longer replicate the model presented in https://juanitorduz.github.io/bikes_pymc/.

Reproducible code example:
Error message:
PyMC version information:
Context for the issue:
No response