lnccbrown / HSSM

Development of HSSM package

jax & jaxlib need to be downgraded #352

Closed: igrahek closed this issue 3 months ago

igrahek commented 4 months ago

Installing HSSM v0.2.0 pulls in jax and jaxlib version 0.4.24.

When running this code:

import hssm

cav_data = hssm.load_data("cavanagh_theta")

model_safe = hssm.HSSM(
    data=cav_data,
    hierarchical=True,
    prior_settings="safe",
    loglik_kind="approx_differentiable",
)
model_safe.sample()

I get the error: ModuleNotFoundError: No module named 'jax.linear_util'

Downgrading jax & jaxlib to 0.4.23 solves the issue.
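To check whether an environment falls in the affected range before sampling, a small stdlib-only script along these lines may help. It is a sketch, not part of HSSM: the cutoff `FIRST_BROKEN = (0, 4, 24)` is an assumption drawn only from this report (0.4.23 works, 0.4.24 fails), and the helper names `parse_release`, `is_affected`, and `check_installed` are hypothetical.

```python
"""Flag jax/jaxlib versions at or above the one reported as broken here.

Assumption: based only on this issue, 0.4.23 works and 0.4.24 raises
ModuleNotFoundError for 'jax.linear_util'. This is not an official HSSM check.
"""
from importlib.metadata import PackageNotFoundError, version

# Hypothetical cutoff taken from the report, not from HSSM's own metadata.
FIRST_BROKEN = (0, 4, 24)


def parse_release(ver: str) -> tuple[int, ...]:
    """Turn '0.4.24' or '0.4.24.dev0' into (0, 4, 24).

    Parsing stops at the first dot-separated piece that is not purely numeric.
    """
    parts = []
    for piece in ver.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def is_affected(ver: str) -> bool:
    """Return True if ver is at or above the first version reported as broken."""
    return parse_release(ver) >= FIRST_BROKEN


def check_installed() -> dict[str, str]:
    """Describe the installed jax and jaxlib versions in this environment."""
    report = {}
    for pkg in ("jax", "jaxlib"):
        try:
            ver = version(pkg)
        except PackageNotFoundError:
            report[pkg] = "not installed"
            continue
        status = "reported broken" if is_affected(ver) else "ok"
        report[pkg] = f"{ver} ({status})"
    return report


if __name__ == "__main__":
    for pkg, status in check_installed().items():
        print(f"{pkg}: {status}")
```

If either package is flagged, the downgrade described above corresponds to `pip install "jax==0.4.23" "jaxlib==0.4.23"`.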

digicosmos86 commented 4 months ago

Hi @igrahek!

This is a good point. We will add version constraints that pin jax to more stable versions.