Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/

save_results TypeError from toy model. #171

Closed. nenasedk closed this issue 2 months ago.

nenasedk commented 2 months ago

Describe the bug

I'm trying to run a toy model based on the multivariate likelihood example (https://jaxns.readthedocs.io/en/latest/examples/mvn_data_mvn_prior.html). The sampler runs fine, and I can convert the state into a results object, make diagnostic plots, etc. However, when I run the save_results function, I get the following error:

TypeError: Object of type ArrayImpl is not JSON serializable

Expected behavior

The results should be saved to a .json or .npz file.

Observed behavior

Full traceback:

Traceback (most recent call last):
  File "/Users/nasedkin/python-packages/petitRADTRANS/petitRADTRANS/retrieval/jax_toy_example.py", line 72, in <module>
    save_results(results,'jaxns_toy_model_results.npz')
  File "/Users/nasedkin/anaconda3/envs/prt3python312/lib/python3.12/site-packages/jaxns/utils.py", line 543, in save_results
    save_pytree(results, save_file)
  File "/Users/nasedkin/anaconda3/envs/prt3python312/lib/python3.12/site-packages/jaxns/utils.py", line 529, in save_pytree
    json.dump(serialise_namedtuple(pytree), fp, indent=2)
  File "/Users/nasedkin/anaconda3/envs/prt3python312/lib/python3.12/json/__init__.py", line 179, in dump
    for chunk in iterable:
  File "/Users/nasedkin/anaconda3/envs/prt3python312/lib/python3.12/json/encoder.py", line 432, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
  File "/Users/nasedkin/anaconda3/envs/prt3python312/lib/python3.12/json/encoder.py", line 406, in _iterencode_dict
    yield from chunks
  File "/Users/nasedkin/anaconda3/envs/prt3python312/lib/python3.12/json/encoder.py", line 406, in _iterencode_dict
    yield from chunks
  File "/Users/nasedkin/anaconda3/envs/prt3python312/lib/python3.12/json/encoder.py", line 439, in _iterencode
    o = _default(o)
        ^^^^^^^^^^^
  File "/Users/nasedkin/anaconda3/envs/prt3python312/lib/python3.12/json/encoder.py", line 180, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type ArrayImpl is not JSON serializable
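The last frames of the traceback show the mechanism: json.dump passes any object it cannot encode natively (dicts, lists, tuples, strings, numbers, booleans and None) to the encoder's default() method, which raises this TypeError. A JAX ArrayImpl is not on that list, and neither is a NumPy ndarray. As a general illustration of the standard-library hook involved (not of how jaxns's serialise_namedtuple works internally), here is a minimal sketch assuming only json and NumPy; array_default is a hypothetical helper, not part of jaxns.

```python
import json

import numpy as np


def array_default(obj):
    """Hypothetical json fallback (not part of jaxns): array-likes -> nested lists.

    Anything exposing __array__ (NumPy arrays, and JAX arrays such as ArrayImpl)
    is converted with np.asarray(...).tolist(); everything else is rejected with
    the same kind of TypeError the stock encoder raises.
    """
    if hasattr(obj, "__array__"):
        return np.asarray(obj).tolist()
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")


data = {"count": 3, "samples": np.arange(3.0)}

# The stock encoder fails on the array value, the same failure as in the traceback.
try:
    json.dumps(data)
except TypeError as err:
    print(err)  # Object of type ndarray is not JSON serializable

# With the fallback hook, the array value is written as a plain list.
print(json.dumps(data, default=array_default))
# {"count": 3, "samples": [0.0, 1.0, 2.0]}
```

Because json calls default() for every object it cannot encode, the hook also handles arrays nested anywhere inside dicts and lists.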

Minimal Verifiable Complete Example

log_likelihood and prior_model are as in the multivariate likelihood example.

import jax
from jax import random
from jaxns import DefaultNestedSampler, Model
from jaxns.utils import save_results

model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
ns = DefaultNestedSampler(
    model=model,
    max_samples=1e8,
    num_live_points=400,
    parameter_estimation=True,
    verbose=True,
)

# Run the sampler, then convert the final state into a results object.
termination_reason, state = jax.jit(ns)(random.PRNGKey(42654))
results = ns.to_results(termination_reason=termination_reason, state=state)
save_results(results, 'jaxns_toy_model_results.npz')  # or .json

JAXNS version

Output of pip freeze | grep jaxns: jaxns==2.5.0

Joshuaalbert commented 2 months ago

Thanks for pointing this out. I've made a fix and will push it in 2.5.2, along with some other minor improvements.

Joshuaalbert commented 2 months ago

As a temporary fix, do this before saving:

import jax
import numpy as np

# Convert every JAX array leaf of the results pytree into a NumPy array.
results = jax.tree.map(np.asarray, results)

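The workaround relies on the results object being a JAX pytree: jax.tree.map(np.asarray, results) calls np.asarray on every array leaf and rebuilds the same NamedTuple around the NumPy copies, which (judging from this workaround) save_results can then serialise. A small sketch, assuming JAX >= 0.4.25 for jax.tree.map (older versions spell it jax.tree_util.tree_map); ToyResults and its fields are hypothetical stand-ins for the real jaxns results type, which has many more fields.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class ToyResults(NamedTuple):
    # Hypothetical stand-in for the jaxns results object.
    evidence: jax.Array
    samples: dict


results = ToyResults(
    evidence=jnp.asarray(-1.25),
    samples={"x": jnp.arange(4.0)},
)

# NamedTuples and dicts are pytree nodes, so jax.tree.map visits every array leaf
# and rebuilds the same container structure around the converted leaves.
results_np = jax.tree.map(np.asarray, results)

print(type(results_np).__name__)      # ToyResults
print(type(results_np.samples["x"]))  # <class 'numpy.ndarray'>
```

The container type and field layout are unchanged; only the leaves move from device-backed JAX arrays to host NumPy arrays.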
Joshuaalbert commented 2 months ago

Also, I notice you're using 2.5.0; due to #168, you should upgrade to >=2.5.1.
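To check whether an environment already has the release recommended above, a small standard-library sketch follows. parse_release and has_min_jaxns are hypothetical helpers, not part of jaxns; the parsing is deliberately simplified and treats pre-releases such as 2.5.1rc1 as equal to the final release. Upgrading itself is pip install -U "jaxns>=2.5.1".

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_release(v: str) -> tuple[int, ...]:
    # Hypothetical helper: keep only the leading numeric release segment,
    # e.g. "2.5.1" -> (2, 5, 1). Simplified: suffixes like "rc1" are dropped.
    m = re.match(r"\d+(?:\.\d+)*", v)
    return tuple(int(p) for p in m.group(0).split(".")) if m else ()


def has_min_jaxns(minimum: tuple[int, ...] = (2, 5, 1)) -> bool:
    # Hypothetical helper: True if an installed jaxns is at least `minimum`.
    try:
        installed = version("jaxns")
    except PackageNotFoundError:
        return False
    return parse_release(installed) >= minimum


print(has_min_jaxns())
```

Tuple comparison does the ordering here, so (2, 5, 0) < (2, 5, 1) < (2, 10, 0) without any string tricks.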