dpiras / cosmopower-jax

Differentiable cosmological emulators: the JAX version of cosmopower
GNU General Public License v3.0
28 stars 3 forks source link

Setting up model with Numpyro? #9

Closed karenperezsarmiento closed 6 days ago

karenperezsarmiento commented 6 days ago

Hello, I'm trying to use NumPyro to do Bayesian inference with CosmoPower as the theory code, but I'm running into some issues (I'm new to NumPyro, so I'm probably making a silly mistake). I'm roughly following the instructions in this blog post. I first created some mock data with cosmopower-jax and then made a model with NumPyro, but I get an error (included below). How did you implement sampling with NumPyro?

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import camb
from cosmopower_jax.cosmopower_jax import CosmoPowerJAX as CPJ
import numpyro
from numpyro import distributions as dist, infer
import numpy as np

from datetime import date
# Build a JAX PRNG key seeded with today's date written as the integer YYYYMMDD.
# NOTE(review): rng_key is not referenced again in the visible snippet; the
# sampler is run with jax.random.PRNGKey(0) instead.
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

# Emulator instance for the 'cmb_tt' probe (presumably the CMB temperature
# power spectrum -- confirm against the cosmopower-jax documentation).
tt_emulator = CPJ(probe='cmb_tt')
# Parameter values used to generate the mock data vector. Each value is
# wrapped in a one-element list, i.e. a batch containing a single point.
true_params = {
    "h": [0.67],
    "ln10^{10}A_s": [3.44],
    "omega_b": [0.02237],
    "omega_cdm": [0.1200],
    "n_s": [0.9649],
    "tau_reio": [0.0544],
}
# Calling predict with a dict of concrete (non-traced) values works here,
# outside of any JAX transformation.
tt_cl = tt_emulator.predict(true_params)
# Modes at which the emulator output is evaluated.
ell = tt_emulator.modes
# Rescale the emulator output by ell*(ell+1)/(2*pi) and by (2.7255e6)**2.
# NOTE(review): 2.7255e6 looks like the CMB temperature in micro-Kelvin, which
# would make this D_ell in muK^2 -- confirm the emulator's output units.
tt_data = (2.7255*1e6)**2*ell*(ell+1)*tt_cl/(2*jnp.pi)
# One entry per mode, all ones. It is passed to the model, where it is used as
# the scale of the Normal likelihood.
# NOTE(review): the name suggests an inverse covariance -- confirm the intent.
cinv = jnp.ones(len(ell))

def like(cinv,datavec=None):
    """NumPyro model: uniform priors on six parameters, Normal likelihood on the TT spectrum.

    This is the version that fails. The sampled values are JAX tracers while
    NUTS evaluates the potential energy, and passing them to
    ``tt_emulator.predict`` inside a dict sends them through
    ``_dict_to_ordered_arr_np``, which calls ``np.stack``. NumPy cannot
    convert tracers, hence the ``TracerArrayConversionError`` in the
    traceback that follows this snippet.

    Args:
        cinv: Used as the ``scale`` argument of the Normal likelihood.
        datavec: Observed data vector; its length sets the plate size, so the
            default ``None`` is not usable as written.
    """
    # Uniform priors; the site names are the keys under which samples are stored.
    omega_cdm = numpyro.sample("omega_cdm", dist.Uniform(0.09, 0.15))
    omega_b = numpyro.sample("omega_b", dist.Uniform(0.017, 0.027))
    logA = numpyro.sample("logA", dist.Uniform(2.6, 3.5))
    h = numpyro.sample("h", dist.Uniform(0.6, 0.8))
    n_s = numpyro.sample("n_s", dist.Uniform(0.9, 1.1))
    tau_reio = numpyro.sample("tau_reio", dist.Uniform(0.047, 0.0617))
    # Same dict layout as the mock-data call, but the values are now traced
    # samples instead of concrete floats.
    sample_params = {
        "h": [h],
        "ln10^{10}A_s": [logA],
        "omega_b": [omega_b],
        "omega_cdm": [omega_cdm],
        "n_s": [n_s],
        "tau_reio": [tau_reio],
    }
    # The error is raised inside this call (dict -> np.stack on tracers).
    theory = tt_emulator.predict(sample_params)
    ell = tt_emulator.modes
    # Same rescaling as applied to the mock data vector.
    theory_dl = (2.7255*1e6)**2*ell*(ell+1)*theory/(2*jnp.pi)
    # One Normal observation site per element of the data vector.
    with numpyro.plate("data",len(datavec)):
        numpyro.sample("datavec", dist.Normal(theory_dl,cinv), obs=datavec)  

# MCMC driver using a NUTS kernel built from the `like` model:
# 2000 warm-up steps, 200000 retained samples, a single chain.
sampler = infer.MCMC(
    infer.NUTS(like),
    num_warmup=2000,
    num_samples=200000,
    num_chains=1,
    progress_bar=True,
)

# Run with a fixed PRNG key; `cinv` is forwarded positionally and `datavec`
# by keyword to the model. The error reported in this issue is raised from
# this call while the sampler is being initialised.
sampler.run(jax.random.PRNGKey(0), cinv,datavec=tt_data)

However, I get this error:

---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
File <timed eval>:1

File ~/.local/lib/python3.11/site-packages/numpyro/infer/mcmc.py:682, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    680 map_args = (rng_key, init_state, init_params)
    681 if self.num_chains == 1:
--> 682     states_flat, last_state = partial_map_fn(map_args)
    683     states = jax.tree.map(lambda x: x[jnp.newaxis, ...], states_flat)
    684 else:

File ~/.local/lib/python3.11/site-packages/numpyro/infer/mcmc.py:443, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites)
    441 # Check if _sample_fn is None, then we need to initialize the sampler.
    442 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 443     new_init_state = self.sampler.init(
    444         rng_key,
    445         self.num_warmup,
    446         init_params,
    447         model_args=args,
    448         model_kwargs=kwargs,
    449     )
    450     init_state = new_init_state if init_state is None else init_state
    451 sample_fn, postprocess_fn = self._get_cached_fns()

File ~/.local/lib/python3.11/site-packages/numpyro/infer/hmc.py:749, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    744 # vectorized
    745 else:
    746     rng_key, rng_key_init_model = jnp.swapaxes(
    747         vmap(random.split)(rng_key), 0, 1
    748     )
--> 749 init_params = self._init_state(
    750     rng_key_init_model, model_args, model_kwargs, init_params
    751 )
    752 if self._potential_fn and init_params is None:
    753     raise ValueError(
    754         "Valid value of `init_params` must be provided with" " `potential_fn`."
    755     )

File ~/.local/lib/python3.11/site-packages/numpyro/infer/hmc.py:693, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
    686 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    687     if self._model is not None:
    688         (
    689             new_init_params,
    690             potential_fn,
    691             postprocess_fn,
    692             model_trace,
--> 693         ) = initialize_model(
    694             rng_key,
    695             self._model,
    696             dynamic_args=True,
    697             init_strategy=self._init_strategy,
    698             model_args=model_args,
    699             model_kwargs=model_kwargs,
    700             forward_mode_differentiation=self._forward_mode_differentiation,
    701         )
    702         if init_params is None:
    703             init_params = new_init_params

File ~/.local/lib/python3.11/site-packages/numpyro/infer/util.py:712, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    710     init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    711 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 712 (init_params, pe, grad), is_valid = find_valid_initial_params(
    713     rng_key,
    714     substitute(
    715         model,
    716         data={
    717             k: site["value"]
    718             for k, site in model_trace.items()
    719             if site["type"] in ["plate"]
    720         },
    721     ),
    722     init_strategy=init_strategy,
    723     enum=has_enumerate_support,
    724     model_args=model_args,
    725     model_kwargs=model_kwargs,
    726     prototype_params=prototype_params,
    727     forward_mode_differentiation=forward_mode_differentiation,
    728     validate_grad=validate_grad,
    729 )
    731 if not_jax_tracer(is_valid):
    732     if device_get(~jnp.all(is_valid)):

File ~/.local/lib/python3.11/site-packages/numpyro/infer/util.py:446, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
    444 # Handle possible vectorization
    445 if is_prng_key(rng_key):
--> 446     (init_params, pe, z_grad), is_valid = _find_valid_params(
    447         rng_key, exit_early=True
    448     )
    449 else:
    450     (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

File ~/.local/lib/python3.11/site-packages/numpyro/infer/util.py:432, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
    428 init_state = (0, rng_key, (prototype_params, 0.0, prototype_grads), False)
    429 if exit_early and not_jax_tracer(rng_key):
    430     # Early return if valid params found. This is only helpful for single chain,
    431     # where we can avoid compiling body_fn in while_loop.
--> 432     _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    433     if not_jax_tracer(is_valid):
    434         if device_get(is_valid):

File ~/.local/lib/python3.11/site-packages/numpyro/infer/util.py:416, in find_valid_initial_params.<locals>.body_fn(state)
    414     z_grad = jacfwd(potential_fn)(params)
    415 else:
--> 416     pe, z_grad = value_and_grad(potential_fn)(params)
    417 z_grad_flat = ravel_pytree(z_grad)[0]
    418 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

    [... skipping hidden 8 frame]

File ~/.local/lib/python3.11/site-packages/numpyro/infer/util.py:298, in potential_energy(model, model_args, model_kwargs, params, enum)
    294 substituted_model = substitute(
    295     model, substitute_fn=partial(_unconstrain_reparam, params)
    296 )
    297 # no param is needed for log_density computation because we already substitute
--> 298 log_joint, model_trace = log_density_(
    299     substituted_model, model_args, model_kwargs, {}
    300 )
    301 return -log_joint

File ~/.local/lib/python3.11/site-packages/numpyro/infer/util.py:69, in log_density(model, model_args, model_kwargs, params)
     57 """
     58 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
     59 latent values ``params``.
   (...)
     66 :return: log of joint density and a corresponding model trace
     67 """
     68 model = substitute(model, data=params)
---> 69 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
     70 log_joint = jnp.zeros(())
     71 for site in model_trace.values():

File ~/.local/lib/python3.11/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File ~/.local/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File ~/.local/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (3 times)]

File ~/.local/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[96], line 16, in like(cinv, datavec)
      7 tau_reio = numpyro.sample("tau_reio", dist.Uniform(0.047, 0.0617))
      8 sample_params = {
      9     "h": [h],
     10     "ln10^{10}A_s": [logA],
   (...)
     14     "tau_reio": [tau_reio],
     15 }
---> 16 theory = tt_emulator.predict(sample_params)
     17 print(theory[0])
     18 ell = tt_emulator.modes

File /gpfs/fs1/home/s/sievers/kaper/gitreps/cosmopower-jax/cosmopower_jax/cosmopower_jax.py:405, in CosmoPowerJAX.predict(self, input_vec)
    403 # convert dict to array, if needed
    404 if isinstance(input_vec, dict):
--> 405     input_vec = self._dict_to_ordered_arr_np(input_vec)  
    407 if len(input_vec.shape) == 1:
    408     input_vec = input_vec.reshape(-1, self.n_parameters)

File /gpfs/fs1/home/s/sievers/kaper/gitreps/cosmopower-jax/cosmopower_jax/cosmopower_jax.py:313, in CosmoPowerJAX._dict_to_ordered_arr_np(self, input_dict)
    300 """
    301 Sort input parameters. Takend verbatim from CP 
    302 (https://github.com/alessiospuriomancini/cosmopower/blob/main/cosmopower/cosmopower_NN.py#LL291C1-L308C73)
   (...)
    310         parameters sorted according to desired order
    311 """
    312 if self.parameters is not None:
--> 313     return np.stack([input_dict[k] for k in self.parameters], axis=1)
    314 else:
    315     return np.stack([input_dict[k] for k in input_dict], axis=1)

File ~/.local/lib/python3.11/site-packages/numpy/core/shape_base.py:443, in stack(arrays, axis, out, dtype, casting)
    372 @array_function_dispatch(_stack_dispatcher)
    373 def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
    374     """
    375     Join a sequence of arrays along a new axis.
    376 
   (...)
    441 
    442     """
--> 443     arrays = [asanyarray(arr) for arr in arrays]
    444     if not arrays:
    445         raise ValueError('need at least one array to stack')

File ~/.local/lib/python3.11/site-packages/numpy/core/shape_base.py:443, in <listcomp>(.0)
    372 @array_function_dispatch(_stack_dispatcher)
    373 def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
    374     """
    375     Join a sequence of arrays along a new axis.
    376 
   (...)
    441 
    442     """
--> 443     arrays = [asanyarray(arr) for arr in arrays]
    444     if not arrays:
    445         raise ValueError('need at least one array to stack')

File ~/.local/lib/python3.11/site-packages/jax/_src/core.py:705, in Tracer.__array__(self, *args, **kw)
    704 def __array__(self, *args, **kw):
--> 705   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[]
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
karenperezsarmiento commented 6 days ago

My problem was very simple to solve! I just needed to pass the parameters as an array instead of a dictionary:

def like(cinv, datavec=None):
    """NumPyro model: uniform priors on six parameters, Normal likelihood on the TT spectrum.

    The sampled parameters are stacked into a single ``jnp`` array before
    being handed to ``tt_emulator.predict``. Passing an array instead of a
    dict avoids the emulator's NumPy-based dict conversion, which cannot
    handle the JAX tracers that NUTS feeds through the model.

    Args:
        cinv: Used as the ``scale`` argument of the Normal likelihood.
            NOTE(review): the name suggests an inverse covariance, but it is
            used as a standard deviation here -- confirm the intent.
        datavec: Observed data vector, scaled the same way as the theory
            vector. When ``None`` the ``datavec`` site is left unobserved and
            the plate size is taken from the theory vector instead.
    """
    # Uniform priors; the site names are the keys under which samples are stored.
    omega_cdm = numpyro.sample("omega_cdm", dist.Uniform(0.09, 0.15))
    omega_b = numpyro.sample("omega_b", dist.Uniform(0.017, 0.027))
    logA = numpyro.sample("logA", dist.Uniform(2.6, 3.5))
    h = numpyro.sample("h", dist.Uniform(0.6, 0.8))
    n_s = numpyro.sample("n_s", dist.Uniform(0.9, 1.1))
    tau_reio = numpyro.sample("tau_reio", dist.Uniform(0.047, 0.0617))

    # With an array input the emulator cannot match values by name, so the
    # order matters. NOTE(review): presumably this must follow the emulator's
    # own parameter ordering -- verify before adding or reordering entries.
    sample_params = jnp.array([omega_b, omega_cdm, h, tau_reio, n_s, logA])
    theory = tt_emulator.predict(sample_params)

    # Take the multipoles from the emulator rather than hard-coding the range,
    # so the theory vector uses the same grid as the mock data.
    ell = tt_emulator.modes
    theory_dl = (2.7255*1e6)**2*ell*(ell+1)*theory/(2*jnp.pi)

    # len(None) would raise; fall back to the theory vector's length so the
    # model can also be traced without observations.
    n_data = theory_dl.shape[-1] if datavec is None else len(datavec)
    with numpyro.plate("data", n_data):
        numpyro.sample("datavec", dist.Normal(theory_dl, cinv), obs=datavec)
dpiras commented 6 days ago

Great @karenperezsarmiento, feel free to open another issue if you find other problems!