Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/

Can't get "Gaussian processes with outliers" example to work #51

Status: Closed (mvsoom closed this 2 years ago)

mvsoom commented 2 years ago

I just installed jaxns with `pip install jaxns` and am running Python 3.8.10.

I'm trying to run https://github.com/Joshuaalbert/jaxns/blob/master/examples/gaussian_processes/gaussian_process_marginalisation.ipynb.

I'm getting a lot of ImportErrors when I run the following cell:

```python
from jaxns import NestedSampler, PriorChain, UniformPrior, HalfLaplacePrior, GaussianProcessKernelPrior
from jaxns import plot_cornerplot, plot_diagnostics
from jaxns.modules.gaussian_process.kernels import RBF, M12, M32
from jaxns import marginalise_dynamic, summary
```

The imports work after changing them to:

```python
from jaxns import NestedSampler
from jaxns.prior_transforms import PriorChain, UniformPrior, HalfLaplacePrior, GaussianProcessKernelPrior
from jaxns import plot_cornerplot, plot_diagnostics
from jaxns.gaussian_process.kernels import RBF, M12, M32
from jaxns.utils import marginalise_dynamic, summary
```
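Because the notebook and the installed package disagree on where these names live, one way to let an example tolerate both layouts is to try each module path in turn. The helper below is a minimal sketch; `import_first` is a hypothetical name, not part of jaxns, and it only helps when a name has moved between modules, not when the object's behaviour has changed.

```python
import importlib
from typing import Any, List, Sequence, Tuple


def import_first(candidates: Sequence[Tuple[str, str]]) -> Any:
    """Return the first attribute importable from a list of (module, name) pairs.

    Hypothetical helper: tries each location in order and collects the
    failure messages so the final error explains every attempt.
    """
    errors: List[str] = []
    for module_name, attr_name in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr_name)
        except (ImportError, AttributeError) as err:
            # ModuleNotFoundError subclasses ImportError, so a missing module
            # and a missing attribute both fall through to the next candidate.
            errors.append(f"{module_name}.{attr_name}: {err}")
    raise ImportError("no candidate import succeeded:\n" + "\n".join(errors))
```

For example, `import_first([("jaxns.modules.gaussian_process.kernels", "RBF"), ("jaxns.gaussian_process.kernels", "RBF")])` would try the notebook's path first and then the path that worked with the pip release.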

However, when I run `run_for_kernel(RBF())`, I get the following error:

```text
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-8-6069bfe93f17> in <module>
      1 # Let us compare these models.
      2 
----> 3 logZ_rbf, logZerr_rbf = run_for_kernel(RBF())
      4 logZ_m12, logZerr_m12 = run_for_kernel(M12())
      5 logZ_m32, logZerr_m32 = run_for_kernel(M32())

<ipython-input-7-6b9bcd99eef1> in run_for_kernel(kernel)
     44         return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))
     45 
---> 46     with PriorChain() as prior_chain:
     47         l = UniformPrior('l', 0., 2.)
     48         uncert = HalfLaplacePrior('uncert', 1.)

AttributeError: __enter__
```

I looked through the source code and the repository history for a fix but found nothing. How can I fix this, or what is the equivalent of `PriorChain()` without the context manager?
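For context: `with X() as x:` only works if the object defines `__enter__` and `__exit__`. On Python 3.8 a missing `__enter__` surfaces as `AttributeError: __enter__` (Python 3.11+ raises a `TypeError` instead), which suggests the installed `PriorChain` predates the context-manager style the notebook uses. The sketch below reproduces the error with a hypothetical class and shows one way a context manager can collect priors declared inside its block. `PlainChain`, `ToyPriorChain` and `ToyUniformPrior` are illustrative stand-ins, not the jaxns implementation.

```python
from typing import List, Optional


class PlainChain:
    """Hypothetical object with no __enter__/__exit__, like an older API."""


def with_error_name(obj: object) -> Optional[str]:
    """Return the exception type raised by `with obj:`, or None if it works."""
    try:
        with obj:
            pass
    except (AttributeError, TypeError) as err:
        # Python <= 3.10 raises AttributeError: __enter__; 3.11+ raises TypeError.
        return type(err).__name__
    return None


class ToyPriorChain:
    """Hypothetical context manager that records priors created inside it.

    Illustration only, not jaxns code; nested chains are not supported.
    """

    _active: Optional["ToyPriorChain"] = None

    def __init__(self) -> None:
        self.priors: List["ToyUniformPrior"] = []

    def __enter__(self) -> "ToyPriorChain":
        ToyPriorChain._active = self
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        ToyPriorChain._active = None
        return False  # never swallow exceptions raised inside the block


class ToyUniformPrior:
    """Hypothetical prior that registers itself with the open chain, if any."""

    def __init__(self, name: str, low: float, high: float) -> None:
        self.name, self.low, self.high = name, low, high
        if ToyPriorChain._active is not None:
            ToyPriorChain._active.priors.append(self)


with ToyPriorChain() as chain:
    lengthscale = ToyUniformPrior("l", 0.0, 2.0)
    uncert = ToyUniformPrior("uncert", 0.0, 1.0)

print([p.name for p in chain.priors])    # ['l', 'uncert']
print(with_error_name(ToyPriorChain()))  # None
print(with_error_name(PlainChain()))     # AttributeError on Python 3.8
```

The class-level `_active` pointer is the simplest way to let objects created inside the block find their chain; a real library may use a different mechanism.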

P.S.: The bleeding-edge version (`pip install git+http://github.com/Joshuaalbert/jaxns.git`) gives an unrelated import error:

```text
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-5-207ff624d917> in <module>
----> 1 from jaxns import NestedSampler
      2 from jaxns.prior_transforms import PriorChain, UniformPrior, HalfLaplacePrior, GaussianProcessKernelPrior
      3 from jaxns import plot_cornerplot, plot_diagnostics
      4 from jaxns.gaussian_process.kernels import RBF, M12, M32
      5 from jaxns.utils import marginalise_dynamic, summary

~/.local/lib/python3.8/site-packages/jaxns/__init__.py in <module>
      2 logging.basicConfig(format='%(levelname)s[%(asctime)s]: %(message)s', level=logging.INFO)
      3 
----> 4 from jaxns.nested_sampler import *
      5 from jaxns.optimisation import *
      6 from jaxns.prior_transforms import *

~/.local/lib/python3.8/site-packages/jaxns/nested_sampler/__init__.py in <module>
----> 1 from jaxns.nested_sampler.nested_sampler import NestedSampler

~/.local/lib/python3.8/site-packages/jaxns/nested_sampler/nested_sampler.py in <module>
      9 from jaxns.internals.maps import chunked_pmap, replace_index, get_index, prepare_func_args
     10 from jaxns.internals.stats import linear_to_log_stats, effective_sample_size
---> 11 from jaxns.nested_sampler.nested_sampling import build_get_sample, get_seed_goal, \
     12     collect_samples, compute_evidence, _update_thread_stats
     13 from jaxns.nested_sampler.utils import summary

~/.local/lib/python3.8/site-packages/jaxns/nested_sampler/nested_sampling.py in <module>
      2 
      3 from jax import numpy as jnp, random, tree_map, value_and_grad
----> 4 from jax._src.lax.lax import dynamic_update_slice
      5 from jax.lax import while_loop
      6 from jaxns.internals.log_semiring import LogSpace

ImportError: cannot import name 'dynamic_update_slice' from 'jax._src.lax.lax' (/home/marnix/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py)
```
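The failing line imports `dynamic_update_slice` from `jax._src.lax.lax`, a private module whose layout can change between JAX releases. The same operation is exposed publicly as `jax.lax.dynamic_update_slice(operand, update, start_indices)`. A minimal sketch of the public call, assuming a recent JAX install (this is not the patch applied in jaxns):

```python
import jax
import jax.numpy as jnp
from jax.lax import dynamic_update_slice  # public path; jax._src.* is private

base = jnp.zeros(5)
patch = jnp.ones(2)

# Overwrite base[1:3] with patch; JAX arrays are immutable, so a new array is returned.
out = dynamic_update_slice(base, patch, (1,))
print(out)  # [0. 1. 1. 0. 0.]

# Unlike static slicing, the start index may be a traced value inside jit.
update_at = jax.jit(lambda b, p, i: dynamic_update_slice(b, p, (i,)))
print(update_at(base, patch, 3))  # [0. 0. 0. 1. 1.]
```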
Joshuaalbert commented 2 years ago

This is because jaxns is at the release-candidate stage for 1.0.0, and the examples require 1.0.0. I'm about 99% happy with the code and will probably push it to pip today. The release also fixes the `dynamic_update_slice` problem, which is caused by JAX changing its API.
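Since the examples need jaxns 1.0.0, a notebook can check the installed version up front instead of failing partway through its imports. A minimal sketch using the standard-library `importlib.metadata` (Python 3.8+); the helper names are hypothetical, and the simple parser ignores pre-release suffixes, so a release candidate such as `1.0.0rc1` counts as `1.0.0` (use `packaging.version` for full PEP 440 ordering).

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def release_tuple(text: str) -> Tuple[int, ...]:
    """Leading numeric components of a version string: '1.0.0rc1' -> (1, 0, 0).

    Pre-release and dev suffixes are ignored (a simplification, not PEP 440).
    """
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"unrecognised version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def installed_version(dist: str = "jaxns") -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def meets_minimum(installed: Optional[str], minimum: str = "1.0.0") -> bool:
    """True if `installed` >= `minimum`, zero-padding so that 1.0 == 1.0.0."""
    if installed is None:
        return False
    have, need = release_tuple(installed), release_tuple(minimum)
    width = max(len(have), len(need))
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


current = installed_version("jaxns")
if not meets_minimum(current):
    print(f"jaxns {current or '(not installed)'} is older than 1.0.0; the examples may not run.")
```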

mvsoom commented 2 years ago

That's great! Thanks.