lnccbrown / HSSM

Development of HSSM package

arviz doesn't work or other problems? #524

Status: Closed

Hellobamboobamboo closed 2 months ago

Hellobamboobamboo commented 2 months ago

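Describe the bug

Running the script below does not produce any ArviZ plots, and simple_ddm_model.graph() raises AttributeError: 'list' object has no attribute 'items' (full console output under Additional context).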

HSSM version 0.2.3
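The traceback ends inside a call whose result depends on another library, so the versions of HSSM's main dependencies may matter as much as the HSSM version. Below is a minimal sketch for collecting them with the standard library's importlib.metadata. The package list is our own assumption, not something HSSM prescribes.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages that seem relevant to this report (our assumption, not an official list).
PACKAGES = ["hssm", "pymc", "arviz", "bambi", "jax", "pytensor", "ssm-simulators"]


def collect_versions(packages):
    """Map each distribution name to its installed version, or None if it is not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, ver in collect_versions(PACKAGES).items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Pasting the printed list next to "HSSM version" in a report costs nothing and rules out one class of questions.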

To Reproduce

import arviz as az  # Visualization of posterior samples
import hddm_wfpt
import hssm
import jax
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytensor  # Graph-based tensor library
from ssms.basic_simulators.simulator import simulator  # Model simulators

# Setting float precision in pytensor
pytensor.config.floatX = "float32"
jax.config.update("jax_enable_x64", False)

# Specify parameter values
v_true, a_true, z_true, t_true = [0.5, 1.5, 0.5, 0.2]

# Simulate data
sim_out = simulator([v_true, a_true, z_true, t_true], model="ddm", n_samples=500)

# Turn data into a pandas dataframe
dataset = pd.DataFrame(
    np.column_stack([sim_out["rts"][:, 0], sim_out["choices"][:, 0]]),
    columns=["rt", "response"],
)
print(dataset)

simple_ddm_model = hssm.HSSM(data=dataset)
print(simple_ddm_model)
simple_ddm_model.graph()

infer_data_simple_ddm_model = simple_ddm_model.sample(
    sampler="nuts_numpyro",  # type of sampler: 'nuts_numpyro', 'nuts_blackjax', or the default PyMC NUTS sampler
    cores=1,  # how many cores to use
    chains=2,  # how many chains to run
    draws=500,  # number of draws from the Markov chain
    tune=500,  # number of burn-in samples
    idata_kwargs=dict(log_likelihood=True),  # return log likelihood
    # mp_ctx="forkserver",
)

print(type(infer_data_simple_ddm_model))
print(infer_data_simple_ddm_model)

print(az.summary(infer_data_simple_ddm_model))
az.plot_trace(
    infer_data_simple_ddm_model,
    var_names="~log_likelihood",  # we exclude the log_likelihood traces here
)
plt.tight_layout()
plt.show()  # show the figure when running as a plain script rather than inline
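To check the data layout that hssm.HSSM receives without installing ssm-simulators, here is a minimal sketch that replaces simulator with a hypothetical stand-in, fake_simulator_output. It returns arrays of shape (n_samples, 1) under the keys "rts" and "choices", as the [:, 0] indexing in the script implies. The choices are assumed to be coded -1/1, and the values are random placeholders, not DDM draws.

```python
import numpy as np
import pandas as pd


def fake_simulator_output(n_samples, seed=0):
    """Hypothetical stand-in for the simulator output: 'rts' and 'choices' of shape (n_samples, 1).

    Values are random placeholders, not draws from a DDM.
    """
    rng = np.random.default_rng(seed)
    rts = rng.uniform(0.3, 2.0, size=(n_samples, 1))
    choices = rng.choice([-1, 1], size=(n_samples, 1))  # assumed -1/1 response coding
    return {"rts": rts, "choices": choices}


def to_hssm_dataframe(sim_out):
    """Stack the first column of 'rts' and 'choices' into an (rt, response) DataFrame."""
    return pd.DataFrame(
        np.column_stack([sim_out["rts"][:, 0], sim_out["choices"][:, 0]]),
        columns=["rt", "response"],
    )


dataset = to_hssm_dataframe(fake_simulator_output(500))
print(dataset.head())
```

Note that np.column_stack upcasts the integer choices to float, so response ends up as -1.0/1.0 in the DataFrame.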

Screenshots

Screenshot (4), Screenshot (5)

Additional context

ArviZ doesn't produce any plots. The full console output is:

WARNING (pytensor.configdefaults): g++ not available, if using conda: conda install m2w64-toolchain
WARNING (pytensor.configdefaults): g++ not detected! PyTensor will be unable to compile C-implementations and will default to Python. Performance may be severely degraded. To remove this warning, set PyTensor flags cxx to an empty string.
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
C:\Users\jc4472\AppData\Local\anaconda3\envs\hssm_env\Lib\site-packages\tqdm\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Hierarchical Sequential Sampling Model
Model: ddm

Response variable: rt,response
Likelihood: analytical
Observations: 500

Parameters:

v:
    Prior: Normal(mu: 0.0, sigma: 2.0)
    Explicit bounds: (-inf, inf)

a:
    Prior: HalfNormal(sigma: 2.0)
    Explicit bounds: (0.0, inf)

z:
    Prior: Uniform(lower: 0.0, upper: 1.0)
    Explicit bounds: (0.0, 1.0)

t:
    Prior: HalfNormal(sigma: 2.0)
    Explicit bounds: (0.0, inf)

Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 10.0)
Traceback (most recent call last):

  File ~\AppData\Local\anaconda3\envs\hssm_env\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File c:\users\jc4472\.spyder-py3\temp.py:61
    simple_ddm_model.graph()

  File ~\AppData\Local\anaconda3\envs\hssm_env\Lib\site-packages\hssm\hssm.py:930 in graph
    ).make_graph(formatting=formatting, response_str=self.response_str)

  File ~\AppData\Local\anaconda3\envs\hssm_env\Lib\site-packages\hssm\utils.py:205 in make_graph
    for plate_label, all_var_names in self.get_plates(var_names).items():

AttributeError: 'list' object has no attribute 'items'
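The last frames show make_graph calling .items() on whatever self.get_plates(var_names) returns, and that value turned out to be a list rather than a dict. We suspect a dependency that supplies the plate information (possibly PyMC) changed this return type, but the thread does not confirm it; the maintainers' fix is in the linked discussion. Because the exception is raised at simple_ddm_model.graph() (line 61 of temp.py), the script probably stops there. The sampling and ArviZ calls after it would then never run, which would explain the missing plots. The sketch below illustrates the mismatch with a hypothetical Plate class and a hypothetical plates_as_dict helper. It is not HSSM's or PyMC's actual code.

```python
from dataclasses import dataclass, field


@dataclass
class Plate:
    """Hypothetical plate object: a label plus the variable names it contains."""

    label: str
    var_names: list = field(default_factory=list)


def plates_as_dict(plates):
    """Return {plate_label: set_of_var_names} from either a dict or a list of plate-like objects.

    The dict branch matches the shape the traceback's .items() call expects;
    the list branch handles a list of objects like the one that triggered the error.
    """
    if isinstance(plates, dict):
        return {label: set(names) for label, names in plates.items()}
    return {plate.label: set(plate.var_names) for plate in plates}


# Hypothetical plate labels: "" for unplated parameters, "500" for the observations.
old_style = {"": {"v", "a", "z", "t"}, "500": {"rt,response"}}
new_style = [Plate("", ["v", "a", "z", "t"]), Plate("500", ["rt,response"])]

# Calling .items() on the list reproduces the error from the traceback.
try:
    new_style.items()
except AttributeError as err:
    print(err)  # 'list' object has no attribute 'items'

print(plates_as_dict(old_style) == plates_as_dict(new_style))  # True
```

Normalizing to one shape at the call site is one way to tolerate both return types; pinning the dependency version is the other, and which one the maintainers chose is described in the discussion.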


digicosmos86 commented 2 months ago

The answer is provided in https://github.com/lnccbrown/HSSM/discussions/525. Closing this issue for now. The new version of HSSM will fix this problem more permanently. Please let us know if you have any questions.