Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/
Apache License 2.0

Examples fail to run in 64bit mode #60

fbartolic closed this issue 2 years ago

fbartolic commented 2 years ago

To reproduce, add the following lines to the example notebook mvn_data_mvn_prior.ipynb:

from jax import config
config.update("jax_enable_x64", True)

I'm getting the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/Users/fb90/Documents/opt/jaxns/examples/simple/mvn_data_mvn_prior.ipynb Cell 6' in <cell line: 6>()
      4 # jit compile
      5 ns = jit(ns)
----> 6 results = ns(random.PRNGKey(4525325), adaptive_evidence_patience=2)

    [... skipping hidden 14 frame]

File ~/miniforge3/envs/caustics/lib/python3.9/site-packages/jaxns/nested_sampler/nested_sampler.py:769, in NestedSampler.__call__(self, key, termination_ess, termination_evidence_uncert, termination_live_evidence_frac, termination_max_num_steps, termination_max_samples, termination_max_num_likelihood_evaluations, adaptive_evidence_patience, adaptive_evidence_stopping_threshold, G, num_live_points, return_state, refine_state)
    766     assert termination_live_evidence_frac is not None, \
    767         "Need the static stopping condition"
    768     # TODO: maybe turn off the other termination criteria since they are really for dynamic
--> 769     state = self._new_static_loop(
    770         init_state=state,
    771         num_slices=self.sampler_kwargs.get('min_num_slices'),
    772         static_num_live_points=num_live_points,
    773         num_parallel_samplers=self.num_parallel_samplers,
    774         termination_ess=termination_ess,
    775         termination_evidence_uncert=termination_evidence_uncert,
    776         termination_live_evidence_frac=termination_live_evidence_frac,
    777         termination_max_num_steps=termination_max_num_steps,
    778         termination_max_samples=termination_max_samples,
    779         termination_max_num_likelihood_evaluations=termination_max_num_likelihood_evaluations,
    780         termination_likelihood_contour=None,
    781         log_L_constraint=None)
    782 if self.dynamic:
    783     if not any([termination_ess is not None,
    784                 termination_evidence_uncert is not None,
    785                 termination_max_num_steps is not None,
    786                 termination_max_samples is not None,
    787                 termination_max_num_likelihood_evaluations is not None]):
    788         # reduce uncertainty by half

File ~/miniforge3/envs/caustics/lib/python3.9/site-packages/jaxns/nested_sampler/nested_sampler.py:459, in NestedSampler._new_static_loop(self, init_state, num_slices, static_num_live_points, num_parallel_samplers, termination_ess, termination_evidence_uncert, termination_live_evidence_frac, termination_max_num_steps, termination_max_samples, termination_max_num_likelihood_evaluations, termination_likelihood_contour, log_L_constraint)
    455     new_state = new_state._replace(done=done, termination_reason=termination_reason)
    457     return new_state, live_reservoir
--> 459 (state, live_reservoir) = while_loop(lambda body_state: jnp.bitwise_not(body_state[0].done),
    460                                      body,
    461                                      (init_state, init_reservoir))
    463 state = collect_samples(state, live_reservoir)
    465 new_thread_stats = _update_thread_stats(state)

    [... skipping hidden 2 frame]

File ~/miniforge3/envs/caustics/lib/python3.9/site-packages/jax/_src/lax/control_flow.py:2135, in _check_tree_and_avals(what, tree1, avals1, tree2, avals2)
   2132 if not all(_map(core.typematch, avals1, avals2)):
   2133   diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
   2134                   tree_unflatten(tree2, avals2))
-> 2135   raise TypeError(f"{what} must have identical types, got\n{diff}.")

TypeError: body_fun output and input must have identical types, got
(NestedSamplerState(key='ShapedArray(uint32[2])', done='ShapedArray(bool[])', sample_collection=SampleCollection(points_U='ShapedArray(float64[100200,3])', points_X={'x': 'ShapedArray(float64[100200,3])'}, log_L_samples='ShapedArray(float32[100200])', log_L_constraint='ShapedArray(float32[100200])', num_likelihood_evaluations='ShapedArray(int32[100200])', log_dZ_mean='ShapedArray(float32[100200])', log_X_mean='ShapedArray(float32[100200])', num_live_points='ShapedArray(float32[100200])', num_slices='ShapedArray(float32[100200])'), evidence_calculation=EvidenceCalculation(log_X_mean='ShapedArray(float32[])', log_X2_mean='ShapedArray(float32[])', log_Z_mean='ShapedArray(float32[])', log_ZX_mean='ShapedArray(float32[])', log_Z2_mean='ShapedArray(float32[])', log_dZ2_mean='ShapedArray(float32[])'), log_L_contour='ShapedArray(float32[])', step_idx='ShapedArray(int32[])', num_likelihood_evaluations='DIFFERENT ShapedArray(int64[]) vs. ShapedArray(int32[])', sample_idx='ShapedArray(int32[])', thread_stats=ThreadStats(evidence='ShapedArray(float64[334])', evidence_uncert='ShapedArray(float64[334])', ess='ShapedArray(float64[334])', log_L_max='ShapedArray(float64[334])', num_likelihood_evaluations='ShapedArray(float64[334])'), termination_reason='ShapedArray(int32[])'), Reservoir(points_U='ShapedArray(float64[300,3])', points_X={'x': 'ShapedArray(float64[300,3])'}, log_L_constraint='ShapedArray(float64[300])', log_L_samples='ShapedArray(float32[300])', num_likelihood_evaluations='ShapedArray(int32[300])', num_slices='ShapedArray(float64[300])')).

EDIT: I'm also seeing the same issue in the numpyro version of jaxns.

jaxns version: http://github.com/Joshuaalbert/jaxns.git@201f78a0bb2d326315d4a3772d6b5f5f534f1ceb
jax version: 0.3.10
python version: 3.9.12
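For context on the TypeError above: lax.while_loop requires body_fun to return a carry with exactly the same shapes and dtypes it received, and the only field marked DIFFERENT is num_likelihood_evaluations (int64 vs. int32). The sketch below reproduces the same class of failure outside jaxns. It assumes, purely for illustration and not taken from the jaxns source, that a counter's dtype was fixed while 64-bit mode was still off and is later incremented by a value whose dtype follows the current x64 default. COUNTER_DTYPE, step and run are hypothetical names; the first config.update call only makes the "before enabling x64" state explicit.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Simulate a library that captured its default integer dtype at import time,
# i.e. before the user enabled 64-bit mode (hypothetical, for illustration).
jax.config.update("jax_enable_x64", False)
COUNTER_DTYPE = jnp.asarray(0).dtype  # int32 while x64 is off

# The user enables 64-bit mode afterwards.
jax.config.update("jax_enable_x64", True)


def step(count):
    # Summing booleans yields the *current* default integer type (int64 now),
    # so int32 + int64 promotes the carry to int64.
    accepted = jnp.array([True, False, True])
    return count + jnp.sum(accepted)


def run():
    init = jnp.zeros((), dtype=COUNTER_DTYPE)
    # while_loop checks that step's output carry type matches its input type.
    return lax.while_loop(lambda c: c < 10, step, init)


try:
    run()
except TypeError as err:
    print("while_loop rejected the carry:", type(err).__name__)
```

Run as-is, this prints the exception class instead of crashing; the exact wording of the underlying message may differ between JAX versions.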

Joshuaalbert commented 2 years ago

Hi @fbartolic, can you try pip install --ignore-installed jaxns==1.1.0 and see whether you still get the problem? I was not able to reproduce it. Also, just checking: is the config.update("jax_enable_x64", True) call at the very top?

fbartolic commented 2 years ago

> Also, just checking but is the config.update("jax_enable_x64", True) at the very top?

Nope, it was after the jaxns import. Placing it before the import solves the problem :). Thank you!
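What fixed it here was ordering. A minimal script header along those lines enables 64-bit mode before importing any JAX-based library and then checks the default dtypes; the jaxns import is left as a comment so the snippet runs on its own.

```python
# Enable 64-bit mode first, before importing libraries built on JAX.
from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
# import jaxns  # import JAX-based libraries only after the config update

# Sanity check: default float and integer arrays are now 64-bit.
print(jnp.zeros(3).dtype, jnp.arange(3).dtype)
```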

Joshuaalbert commented 2 years ago

I think we might be able to resolve this even if the config update comes after import jaxns. It looks like only num_likelihood_evaluations is having a problem, likely due to an increment somewhere that is not explicitly cast.
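To find which carry field changes dtype without reading a long TypeError, one option is to trace the loop body abstractly with jax.eval_shape and compare input and output dtypes leaf by leaf. This is a sketch with a hypothetical two-field Carry standing in for the sampler state; find_dtype_mismatches is a hypothetical helper, not part of jaxns or JAX.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


class Carry(NamedTuple):
    # Hypothetical stand-in for the sampler state.
    log_L: jax.Array
    num_likelihood_evaluations: jax.Array


def body(carry):
    # The increment's dtype follows the x64 default (int64), not the carry's int32.
    n_new = jnp.sum(jnp.array([True, True]))
    return Carry(carry.log_L + 1.0, carry.num_likelihood_evaluations + n_new)


def find_dtype_mismatches(body_fun, init):
    """Return (path, in_dtype, out_dtype) for every carry leaf whose dtype changes."""
    out = jax.eval_shape(body_fun, init)  # abstract trace, no real computation
    in_leaves = jax.tree_util.tree_leaves_with_path(init)
    out_leaves = jax.tree_util.tree_leaves(out)
    return [
        (jax.tree_util.keystr(path), x.dtype, y.dtype)
        for (path, x), y in zip(in_leaves, out_leaves)
        if x.dtype != y.dtype
    ]


init = Carry(jnp.zeros((), jnp.float32), jnp.zeros((), jnp.int32))
print(find_dtype_mismatches(body, init))
```

The float field stays float32 because adding a Python float is weakly typed, so only the integer counter is reported.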

Joshuaalbert commented 2 years ago

@fbartolic I pushed a quick fix: I looked for places where num_likelihood_evaluations is incremented without controlling the type, and explicitly set the correct type there. Could you check whether you still get the problem with an import order like this:

import jaxns
from jax import config
config.update("jax_enable_x64", True)

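The idea behind the quick fix, explicitly setting the type wherever the counter is incremented, can be illustrated in isolation. This is a minimal sketch, not the actual jaxns patch: casting the increment to the counter's own dtype keeps the while_loop carry type stable even with 64-bit mode enabled.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)


def step(count):
    accepted = jnp.array([True, False, True])
    # Cast the increment to the carry's dtype so int32 + int32 stays int32.
    return count + jnp.sum(accepted).astype(count.dtype)


init = jnp.zeros((), dtype=jnp.int32)  # dtype fixed independently of the x64 flag
final = lax.while_loop(lambda c: c < 10, step, init)
print(final, final.dtype)
```

Passing dtype=count.dtype to jnp.sum would work equally well; the point is that the increment never chooses its own dtype.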
Joshuaalbert commented 2 years ago

I believe this is resolved now, even when jaxns is imported above the config statement. If anyone comes across it again, please reopen.