Closed: gingjehli closed this issue 1 month ago.
Hi @gingjehli,
The latest version of JAX has deprecated the `jax.config` submodule (importing it directly, as in `from jax.config import config`), which was causing some incompatibility issues. We have fixed this in #443. Please reinstall HSSM from the repo; that should resolve the error.
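For reference, here is a minimal sketch of the import change. It assumes a recent JAX release in which the `jax.config` submodule is deprecated or removed. Instead of importing the submodule, reach the configuration object through the top-level `jax` package. This is only an illustration of the pattern that replaces the `from jax.config import config` line in the original code. It is not the actual diff from #443.

```python
import jax
import jax.numpy as jnp

# Deprecated pattern, avoid:
#     from jax.config import config
# Supported pattern: the config object is an attribute of the jax package.
from jax import config

# Disable 64-bit mode so JAX creates 32-bit floats by default.
config.update("jax_enable_x64", False)

# `config` and `jax.config` refer to the same object.
assert config is jax.config
print(jnp.zeros(3).dtype)  # float32
```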
With hssm 0.2.2 only, I get the following error when calling:
`hssm.set_floatX("float32")`
This is my code:
```python
import numpy as np
import matplotlib as plt
import pandas as pd
import hddm_wfpt
import matplotlib as plt
import matplotlib.pyplot as pltplot
plt.use('Agg')
import arviz as az
import bambi as bmb
import pathlib
import pytensor
import hssm
import ssms
from ssms.basic_simulators import simulator
import jax
from jax.config import config
config.update("jax_enable_x64", False)
%matplotlib inline
%config InlineBackend.figure_format='retina'
hssm.set_floatX("float32")
```
This is the error message:
```
Setting PyTensor floatX type to float32.

AttributeError                            Traceback (most recent call last)
Cell In[94], line 32
     29 get_ipython().run_line_magic('matplotlib', 'inline')
     30 get_ipython().run_line_magic('config', "InlineBackend.figure_format='retina'")
---> 32 hssm.set_floatX("float32")

File ~/.conda/envs/pyHSSM3/lib/python3.11/site-packages/hssm/utils.py:283, in set_floatX(dtype, update_jax)
    281 if update_jax:
    282     jax_enable_x64 = dtype == "float64"
--> 283     jax.config.update("jax_enable_x64", jax_enable_x64)
    285     _logger.info(
    286         'Setting "jax_enable_x64" to %s. '
    287         + "If this is not intended, please set `jax` to False.",
    288         jax_enable_x64,
    289     )

AttributeError: module 'jax.config' has no attribute 'update'
```
Note that I get this error regardless of whether I import jax. Also, I don't get this error with hssm 0.2.0 or hssm 0.2.1; it seems specific to hssm 0.2.2.
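The message "module 'jax.config' has no attribute 'update'" means that `jax.config` was a module object at the time of the call, not JAX's configuration object. That would happen if something in the session imported the submodule directly. The culprit could be the notebook, HSSM, or another dependency, which fits an error that appears without an explicit `import jax`. Below is a hedged diagnostic sketch that checks for this state. `jax_config_is_shadowed` is a hypothetical helper name.

```python
import sys
import types

import jax


def jax_config_is_shadowed() -> bool:
    """Return True if `jax.config` is a module rather than JAX's config object."""
    return isinstance(jax.config, types.ModuleType)


if jax_config_is_shadowed():
    # This is the state implied by the AttributeError in the traceback.
    print("jax.config is the submodule; something imported `jax.config` directly.")
else:
    print("jax.config is the config object; update available:",
          hasattr(jax.config, "update"))

# Report whether the `jax.config` submodule was imported in this session.
print("'jax.config' in sys.modules:", "jax.config" in sys.modules)
```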