Closed: Hellobamboobamboo closed this issue 2 months ago.
The answer is provided in https://github.com/lnccbrown/HSSM/discussions/525. Closing this issue for now. The new version of HSSM will fix this problem more permanently. Please let us know if you have any questions.
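Describe the bug: Calling `simple_ddm_model.graph()` raises `AttributeError: 'list' object has no attribute 'items'` from inside HSSM (full traceback under Additional context), and the ArviZ plots at the end of the script are never produced.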
HSSM version 0.2.3
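The traceback ends inside HSSM's graph code, which calls a `get_plates` method that HSSM's graph class appears to inherit from PyMC's model-graph utilities. Because of that, it can help to report the versions of HSSM's main dependencies next to the HSSM version. Below is a minimal sketch using the standard-library `importlib.metadata`. The package list is an assumption, so adjust it to your environment.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages whose versions plausibly matter for this traceback (assumed list).
PACKAGES = ["hssm", "pymc", "pytensor", "arviz", "jax", "numpyro"]


def collect_versions(packages):
    """Return a {distribution name: version string} mapping.

    Packages that are not installed are reported as "not installed"
    instead of raising PackageNotFoundError.
    """
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


if __name__ == "__main__":
    # Print one "name: version" line per package, ready to paste into an issue.
    for name, ver in collect_versions(PACKAGES).items():
        print(f"{name}: {ver}")
```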
To Reproduce
```python
import arviz as az  # Visualization
import hddm_wfpt
import hssm
import jax
import numpy as np
import pandas as pd
import pytensor  # Graph-based tensor library
from matplotlib import pyplot as plt
from ssms.basic_simulators.simulator import simulator  # Model simulators

# Setting float precision in pytensor
pytensor.config.floatX = "float32"
jax.config.update("jax_enable_x64", False)

# Specify parameter values
v_true, a_true, z_true, t_true = [0.5, 1.5, 0.5, 0.2]

# Simulate data
sim_out = simulator([v_true, a_true, z_true, t_true], model="ddm", n_samples=500)

# Turn data into a pandas dataframe
dataset = pd.DataFrame(
    np.column_stack([sim_out["rts"][:, 0], sim_out["choices"][:, 0]]),
    columns=["rt", "response"],
)

dataset  # displays the DataFrame in an interactive console

simple_ddm_model = hssm.HSSM(data=dataset)
print(simple_ddm_model)
simple_ddm_model.graph()  # raises the AttributeError shown in the traceback

infer_data_simple_ddm_model = simple_ddm_model.sample(
    sampler="nuts_numpyro",  # type of sampler: 'nuts_numpyro', 'nuts_blackjax', or the default pymc nuts sampler
    cores=1,  # how many cores to use
    chains=2,  # how many chains to run
    draws=500,  # number of draws from the markov chain
    tune=500,  # number of burn-in samples
    idata_kwargs=dict(log_likelihood=True),  # return log likelihood
)  # mp_ctx="forkserver")

type(infer_data_simple_ddm_model)
infer_data_simple_ddm_model

az.summary(infer_data_simple_ddm_model)
az.plot_trace(
    infer_data_simple_ddm_model,
    var_names="~log_likelihood",  # we exclude the log_likelihood traces here
)
plt.tight_layout()
```
Additional context: The script does not produce any plots with ArviZ. Console output:
```
WARNING (pytensor.configdefaults): g++ not available, if using conda: `conda install m2w64-toolchain`
WARNING (pytensor.configdefaults): g++ not detected! PyTensor will be unable to compile C-implementations and will default to Python. Performance may be severely degraded. To remove this warning, set PyTensor flags cxx to an empty string.
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
C:\Users\jc4472\AppData\Local\anaconda3\envs\hssm_env\Lib\site-packages\tqdm\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Hierarchical Sequential Sampling Model
Model: ddm

Response variable: rt,response
Likelihood: analytical
Observations: 500

Parameters:

v:
    Prior: Normal(mu: 0.0, sigma: 2.0)
    Explicit bounds: (-inf, inf)

a:
    Prior: HalfNormal(sigma: 2.0)
    Explicit bounds: (0.0, inf)

z:
    Prior: Uniform(lower: 0.0, upper: 1.0)
    Explicit bounds: (0.0, 1.0)

t:
    Prior: HalfNormal(sigma: 2.0)
    Explicit bounds: (0.0, inf)

Lapse probability: 0.05
Lapse distribution: Uniform(lower: 0.0, upper: 10.0)
Traceback (most recent call last):

  File ~\AppData\Local\anaconda3\envs\hssm_env\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File c:\users\jc4472\.spyder-py3\temp.py:61
    simple_ddm_model.graph()

  File ~\AppData\Local\anaconda3\envs\hssm_env\Lib\site-packages\hssm\hssm.py:930 in graph
    ).make_graph(formatting=formatting, response_str=self.response_str)

  File ~\AppData\Local\anaconda3\envs\hssm_env\Lib\site-packages\hssm\utils.py:205 in make_graph
    for plate_label, all_var_names in self.get_plates(var_names).items():

AttributeError: 'list' object has no attribute 'items'
```
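The last frame of the traceback shows HSSM's `make_graph` calling `.items()` on whatever `get_plates(var_names)` returns. That means it expects a dict from plate label to variable names, but here it received a list. Below is a library-free illustration of that mismatch and of one defensive way to accept both shapes. The `Plate` dataclass and `plates_as_dict` helper are hypothetical stand-ins, not HSSM's or PyMC's actual code, and this is not the fix referenced in discussion #525.

```python
from dataclasses import dataclass, field


@dataclass
class Plate:
    """Hypothetical stand-in for a plate: a label plus the variables on it."""

    label: str
    var_names: list = field(default_factory=list)


def plates_as_dict(plates):
    """Return {plate label: [variable names]} for either input shape.

    A dict (the shape the failing line expects) is copied; a list of
    Plate objects (the shape that triggers the AttributeError) is converted.
    """
    if isinstance(plates, dict):
        return {label: list(names) for label, names in plates.items()}
    return {plate.label: list(plate.var_names) for plate in plates}


# Illustrative plates only; labels and grouping are made up for the example.
dict_shape = {"": ["v", "a", "z", "t"], "obs": ["rt,response"]}
list_shape = [Plate("", ["v", "a", "z", "t"]), Plate("obs", ["rt,response"])]

# Calling .items() on the list reproduces the message from the traceback.
try:
    list_shape.items()
except AttributeError as err:
    print(err)  # 'list' object has no attribute 'items'

# After normalization, both shapes can be iterated with .items().
for label, names in plates_as_dict(list_shape).items():
    print(repr(label), names)
```

Normalizing at the call site keeps the loop body unchanged regardless of which shape the upstream method returns. Upgrading to a version where the two sides agree, as the maintainers suggest, is the cleaner fix.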