dirknbr closed this issue 3 months ago.
I might need more details on the versions involved and the full stack trace, but as a guess it looks like this issue.
The fix would be to upgrade the jax you're running to 0.4.27 or later (I think pip install -U "jax>=0.4.27" should do it).
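To make the notebook fail fast on an older jax before it reaches sampling, a rough version check could go at the top. This is only a sketch: the helper names (parse_release, jax_is_new_enough) are hypothetical, the 0.4.27 threshold is taken from the suggestion above, and the comparison looks only at the leading numeric release parts, so pre-release tags such as "rc1" or ".dev1" are ignored.

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum jax release suggested in the thread as containing the fix.
MIN_JAX = (0, 4, 27)


def parse_release(ver: str) -> tuple[int, ...]:
    """Return the leading integer parts of a version string, e.g. '0.4.30.dev1' -> (0, 4, 30)."""
    parts = []
    for piece in ver.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break  # a purely non-numeric segment such as 'dev1' ends the release part
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # a suffix like '27rc1' also ends the release part
    return tuple(parts)


def jax_is_new_enough(minimum: tuple[int, ...] = MIN_JAX) -> bool:
    """True if an installed jax distribution has a release of at least `minimum`."""
    try:
        installed = version("jax")
    except PackageNotFoundError:
        return False
    # Tuple comparison is element-wise, so (0, 5, 0) >= (0, 4, 27) is True.
    return parse_release(installed) >= minimum
```

If jax was already imported in the Colab session before upgrading, the new version generally only takes effect after restarting the runtime, so run the check after the restart.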
Thanks, that has indeed fixed the error.
This code ran fine some months ago but now fails. This was in a CPU Colab runtime:
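import jax
# bx_model was built earlier in the notebook (its definition is not shown in the report)
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0), num_draws=1000, num_chains=2)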