Hi @fbartolic, can you try pip install --ignore-installed jaxns==1.1.0 and see if you still get the problem? I was not able to reproduce it. Also, just checking: is the config.update("jax_enable_x64", True) at the very top?
Nope, it was after the jaxns import. Placing it before the import solves the problem :). Thank you!
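A plausible reason the order matters, stated as an assumption since the thread does not show how jaxns builds its internal state: JAX fixes an array's default integer width when the array is created, so anything built at import time before the flag is flipped stays 32-bit. The sketch below uses only JAX; `before` stands in for hypothetical import-time state.

```python
import jax
import jax.numpy as jnp

# Stand-in for state a library might build at import time (hypothetical),
# created while x64 is still disabled: the default integer dtype is int32.
before = jnp.arange(3)

jax.config.update("jax_enable_x64", True)

# Arrays created after the flag flip use the 64-bit default instead.
after = jnp.arange(3)

# The earlier array keeps its int32 dtype; mixing the two promotes to int64.
mixed = before + after
print(before.dtype, after.dtype, mixed.dtype)
```

Under this assumption, enabling x64 before importing jaxns avoids the mismatch simply because no 32-bit state exists yet when the flag is set.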
I think we might be able to resolve this even if it comes after import jaxns. It looks like num_likelihood_evaluations is the one having a problem, likely due to an uncast increment somewhere. I will reopen and see if I can track it down.
@fbartolic I pushed a quick fix: I looked for places where num_likelihood_evaluations is incremented without a controlled type and explicitly set the correct type there. Maybe you could check whether you still get the same problem with the imports ordered like this:
import jaxns
import jax
jax.config.update("jax_enable_x64", True)
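The kind of fix described above, casting the incremented counter back to a fixed dtype, can be illustrated with a hypothetical counter loop. `count_evals` and `stale_counter` are illustrative names, not jaxns code, and this is only an assumption about the shape of the original bug. `lax.while_loop` requires the loop carry to keep identical dtypes across iterations, so an int32 counter plus an int64 increment is rejected unless the sum is cast back.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)


def count_evals(init_count, n_steps=5, cast=True):
    """Hypothetical loop that bumps an evaluation counter once per step."""

    def cond(carry):
        step, _ = carry
        return step < n_steps

    def body(carry):
        step, count = carry
        # With x64 on, an integer array built with the default dtype is int64.
        increment = jnp.ones((), dtype=int)
        new_count = count + increment  # int32 + int64 promotes to int64
        if cast:
            # Pin the result back to the carry's dtype so the loop state is stable.
            new_count = new_count.astype(count.dtype)
        return step + 1, new_count

    return lax.while_loop(cond, body, (0, init_count))


# Stand-in for a counter that was created as int32 before x64 was enabled.
stale_counter = jnp.zeros((), dtype=jnp.int32)

_, total = count_evals(stale_counter)
print(total, total.dtype)

try:
    count_evals(stale_counter, cast=False)
except TypeError as err:
    print("uncast increment fails:", type(err).__name__)
```

Casting to the carry's own dtype (rather than hard-coding int64) keeps the loop valid regardless of whether the counter was created before or after the flag was set.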
I believe this is resolved now, even when jaxns is imported above the config statement. If someone comes across it again, please reopen.
To reproduce, add … in the example mvn_data_mvn_prior.ipynb. I'm getting the following error:
EDIT: I'm also seeing the same issue in the numpyro version of jaxns.
jaxns version: http://github.com/Joshuaalbert/jaxns.git@201f78a0bb2d326315d4a3772d6b5f5f534f1ceb
jax version: 0.3.10
python 3.9.12