Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/

Running jaxns with Float64 #85

Closed: tomkimpson closed this issue 1 year ago.

tomkimpson commented 1 year ago

First of all, thanks for making this package. It looks really great and has the potential to be a very useful tool!

I am playing around with jaxns, having come from using dynesty in the past.

I am trying to get a simple example working: I define a likelihood function that takes a single parameter, put a prior on that parameter, and then run the nested sampler. An example can be found here: https://github.com/tomkimpson/StateSpacePTA.jl/blob/jax/py_src/jax_ns_example.py

Now, I can evaluate the likelihood function once by hand and everything works. Similarly, the sanity check provided by jaxns runs fine. But when I try to run the full sampler, I get a type error:

TypeError: body_fun output and input must have identical types, got
(UniDimProposalState(key='ShapedArray(uint32[2])', process_step='ShapedArray(int32[])', proposal_count='ShapedArray(int32[])', num_likelihood_evaluations='ShapedArray(int32[])', point_U0='DIFFERENT ShapedArray(float64[1]) vs. ShapedArray(float32[1])', log_L0='ShapedArray(float32[])', direction='ShapedArray(float64[])', left='ShapedArray(float64[])', right='ShapedArray(float64[])', point_U='ShapedArray(float64[1])', t='ShapedArray(float32[])', log_L_constraint='ShapedArray(float64[])'), 'ShapedArray(float32[])').

This is very similar to https://github.com/google/jax/discussions/5457, which suggests that somewhere there is an attempted conversion from float32 to float64 (note point_U0: float64[1] vs. float32[1]). At the start of the script I enable float64 in JAX, and all my calculations are done in float64. Do I need to tell jaxns to work with float64 too?
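For anyone hitting the same message: it is raised by jax.lax.while_loop, which requires the loop carry to keep exactly the same dtype on every iteration. The minimal sketch below is a standalone illustration, not jaxns code (run_loop is a hypothetical helper). It reproduces the error by adding a float64 value to a float32 carry, which promotes the carry to float64 when x64 mode is on.

```python
import jax

jax.config.update("jax_enable_x64", True)  # enable x64 before creating any arrays

import jax.numpy as jnp
from jax import lax


def run_loop(init):
    """Hypothetical helper: count up to 10 with lax.while_loop."""

    def cond_fun(x):
        return x < 10.0

    def body_fun(x):
        # Adding a float64 array promotes a float32 carry to float64,
        # so the output type no longer matches the input type.
        return x + jnp.asarray(1.0, dtype=jnp.float64)

    return lax.while_loop(cond_fun, body_fun, init)


# Mismatched carry: float32 in, float64 out -> TypeError at trace time.
try:
    run_loop(jnp.asarray(0.0, dtype=jnp.float32))
except TypeError as err:
    print("TypeError:", str(err).splitlines()[0])

# Consistent carry: float64 in, float64 out -> runs fine.
print(run_loop(jnp.asarray(0.0, dtype=jnp.float64)))
```

In the jaxns traceback the mismatched field is point_U0, which suggests that some of the sampler's state was created as float32 before x64 mode took effect.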

Thanks for your help, and kudos on the package!

Joshuaalbert commented 1 year ago

Thanks for the kind words!

What if you put config.update("jax_enable_x64", True) at the very top of your script?
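A minimal sketch of that ordering, assuming a recent JAX where the flag is set through jax.config.update (older snippets import config from jax.config instead). The point is that the flag is set before anything else imports or uses JAX:

```python
# Enable 64-bit mode first, before any other import that might touch JAX.
import jax

jax.config.update("jax_enable_x64", True)

# Only now import the rest (jax.numpy, jaxns, your own modules, ...).
import jax.numpy as jnp

x = jnp.zeros(3)
print(x.dtype)  # float64
```

Setting the environment variable JAX_ENABLE_X64=1 before starting Python has the same effect and sidesteps import-order issues entirely.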

tomkimpson commented 1 year ago

Hmm, I have that already.

Joshuaalbert commented 1 year ago

Try placing it above all other imports.
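Why the position matters: an array created while x64 mode is still off keeps its float32 dtype after the flag is turned on. If a library builds arrays when it is imported (an assumption about what happened here; the thread does not say which jaxns internals were affected), importing it before the config.update call leaves float32 values inside code that later receives float64 inputs. The sketch below simulates this with a hypothetical module-level constant, LIB_DEFAULT, standing in for library state. It assumes JAX_ENABLE_X64 is not already set in the environment.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for state a library might create at import time.
# x64 mode is still off here, so the default float dtype is float32.
LIB_DEFAULT = jnp.zeros(1)

# Enabling x64 now is too late for LIB_DEFAULT; it only affects new arrays.
jax.config.update("jax_enable_x64", True)

user_value = jnp.zeros(1)

print(LIB_DEFAULT.dtype)  # float32
print(user_value.dtype)   # float64
```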

tomkimpson commented 1 year ago

D'oh! Thanks for the assist!