First of all, thanks for making this package. It looks really great and has the potential to be a very useful tool!
I am playing around with jaxns, having used dynesty in the past.
I am trying to get a simple example working: I define a likelihood function that takes a single parameter, put a prior on that parameter, and try to run the nested sampler. An example can be found here: https://github.com/tomkimpson/StateSpacePTA.jl/blob/jax/py_src/jax_ns_example.py
Now, I can evaluate the likelihood function manually for a single input and everything works OK. Similarly, the sanity check provided by jaxns passes. But when I try to run the full sampler, I get a type error:
TypeError: body_fun output and input must have identical types, got
(UniDimProposalState(key='ShapedArray(uint32[2])', process_step='ShapedArray(int32[])', proposal_count='ShapedArray(int32[])', num_likelihood_evaluations='ShapedArray(int32[])', point_U0='DIFFERENT ShapedArray(float64[1]) vs. ShapedArray(float32[1])', log_L0='ShapedArray(float32[])', direction='ShapedArray(float64[])', left='ShapedArray(float64[])', right='ShapedArray(float64[])', point_U='ShapedArray(float64[1])', t='ShapedArray(float32[])', log_L_constraint='ShapedArray(float64[])'), 'ShapedArray(float32[])').
This is very similar to https://github.com/google/jax/discussions/5457, which suggests that somewhere there is an attempted conversion between float32 and float64. At the start of the script I enable float64 in JAX, and all my calculations are done in float64. Do I need to tell jaxns to work with float64 too?
Thanks for your help, and kudos on the package!
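For reference, the same class of error can be reproduced in plain JAX, without jaxns. This is a minimal sketch of the mechanism discussed in the linked thread, not a diagnosis of jaxns internals: it assumes some field of the loop state (like `point_U0` above) is created as float32 while the loop body, fed float64 data, returns float64. The names `OFFSET`, `make_body`, `cond` and `run` are hypothetical, and the exact error wording differs between JAX versions.

```python
import jax

# Enable 64-bit types; JAX recommends doing this at startup, before creating arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import lax

# Hypothetical float64 constant, standing in for user code that computes in float64.
OFFSET = jnp.asarray(1.0, dtype=jnp.float64)


def make_body(offset):
    def body(carry):
        i, x = carry
        # Under x64, float32 + float64 promotes to float64, changing the carry's dtype.
        return i + 1, x + offset
    return body


def cond(carry):
    i, _ = carry
    return i < 3


def run(x0):
    # The carry is (loop counter, state); while_loop requires its types to be unchanged by body.
    return lax.while_loop(cond, make_body(OFFSET), (jnp.int32(0), x0))


# 1. Carry starts as float32 but the body returns float64: while_loop raises TypeError.
x32 = jnp.asarray([0.5], dtype=jnp.float32)
try:
    run(x32)
    mismatch_error = None
except TypeError as err:
    mismatch_error = err

# 2. Carry starts as float64: input and output types agree and the loop runs.
x64 = jnp.asarray([0.5], dtype=jnp.float64)
i_final, x_final = run(x64)
```

In the sketch, creating the initial state in float64 (or casting the body's output back with `.astype(x.dtype)`) makes the types agree. Whether jaxns lets the user control the dtype of its internal state is the open question above; the sketch does not answer it.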