lnccbrown / HSSM

Development of HSSM package

Error message for `set_floatX` in hssm=0.2.2 #451

Closed: gingjehli closed this 1 month ago

gingjehli commented 1 month ago

With hssm=0.2.2 only, I get the following error message when calling `hssm.set_floatX("float32")`.

This is my code:

```python
import numpy as np
import matplotlib as plt
import pandas as pd

import hddm_wfpt
import matplotlib.pyplot as pltplot

plt.use('Agg')  # matplotlib.use: switch to the non-interactive Agg backend
import arviz as az
import bambi as bmb

import pathlib
import pytensor

import hssm
import ssms
from ssms.basic_simulators import simulator

import jax
from jax.config import config

config.update("jax_enable_x64", False)  # keep JAX in 32-bit mode

%matplotlib inline
%config InlineBackend.figure_format='retina'

hssm.set_floatX("float32")  # raises AttributeError with hssm 0.2.2
```

This is the error message:

```
Setting PyTensor floatX type to float32.

AttributeError                            Traceback (most recent call last)
Cell In[94], line 32
     29 get_ipython().run_line_magic('matplotlib', 'inline')
     30 get_ipython().run_line_magic('config', "InlineBackend.figure_format='retina'")
---> 32 hssm.set_floatX("float32")

File ~/.conda/envs/pyHSSM3/lib/python3.11/site-packages/hssm/utils.py:283, in set_floatX(dtype, update_jax)
    281 if update_jax:
    282     jax_enable_x64 = dtype == "float64"
--> 283     jax.config.update("jax_enable_x64", jax_enable_x64)
    285     _logger.info(
    286         'Setting "jax_enable_x64" to %s. '
    287         + "If this is not intended, please set jax to False.",
    288         jax_enable_x64,
    289     )

AttributeError: module 'jax.config' has no attribute 'update'
```

Note that I get this error message regardless of whether I import jax myself. I don't get it with hssm=0.2.0 or hssm=0.2.1, so it seems specific to hssm=0.2.2.

digicosmos86 commented 1 month ago

Hi @gingjehli,

The latest version of JAX has deprecated the use of `jax.config`, which was causing some minor incompatibility issues. We fixed the issue in #443. Please reinstall HSSM from the repo and the error should go away.
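The thread does not spell out the exact mechanism, so what follows is an assumption: Python binds a freshly imported submodule as an attribute of its parent package. A legacy `from jax.config import config` (in user code or in any imported library) can therefore replace the `jax.config` *object* with the `jax.config` *module*, which has no `update`, and a later `jax.config.update(...)` call fails exactly as in the traceback. The toy package `toyjax` and the helper `demo_submodule_shadowing` below are hypothetical and only mimic that layout; they are not JAX's actual source.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Toy package standing in for JAX's layout (hypothetical, for illustration only).
INIT_SRC = textwrap.dedent('''
    class _Config:
        """Minimal stand-in for a config object with an update() method."""

        def __init__(self):
            self.values = {}

        def update(self, name, value):
            self.values[name] = value

    _config = _Config()
    config = _config  # package attribute `toyjax.config` is the object
''')

# Same-named legacy submodule that re-exports the object.
SHIM_SRC = "from toyjax import _config as config\n"


def demo_submodule_shadowing() -> tuple[bool, bool]:
    """Return whether toyjax.config has update() before and after the legacy import."""
    root = Path(tempfile.mkdtemp())
    pkg_dir = root / "toyjax"
    pkg_dir.mkdir()
    (pkg_dir / "__init__.py").write_text(INIT_SRC)
    (pkg_dir / "config.py").write_text(SHIM_SRC)

    sys.path.insert(0, str(root))
    importlib.invalidate_caches()
    try:
        import toyjax

        before = hasattr(toyjax.config, "update")  # still the config object

        # Loading the submodule makes Python bind toyjax.config to the module.
        from toyjax.config import config

        config.update("enable_x64", False)  # the object itself still works
        after = hasattr(toyjax.config, "update")  # now a module: no update()
    finally:
        sys.path.remove(str(root))
        # Drop cached modules so the demo can be rerun from scratch.
        sys.modules.pop("toyjax.config", None)
        sys.modules.pop("toyjax", None)
    return before, after


print(demo_submodule_shadowing())  # (True, False)
```

If this assumption holds, the pattern that avoids the problem in user code is `import jax` followed by `jax.config.update("jax_enable_x64", False)`, which never imports the `jax.config` submodule.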