jax-ml / bayeux

State-of-the-art inference for your Bayesian models.
https://jax-ml.github.io/bayeux/
Apache License 2.0

PyMC + Blackjax Fail with latest version 1.2.0 #43

Closed: juanitorduz closed this issue 6 months ago

juanitorduz commented 6 months ago

Hi 👋! I am trying to run the PyMC example with blackjax==1.2.0 and I am getting the error below (with blackjax==1.1.1 it works fine) 🥲

Should I open an issue in the blackjax repo as well? I am not sure where the error is coming from.

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[1], line 24
     15     pm.Normal(
     16         "observed",
     17         avg_effect + avg_stddev * school_effects,
     18         treatment_stddevs,
     19         observed=treatment_effects,
     20     )
     22 bx_model = bx.Model.from_pymc(model)
---> 24 idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))
     26 az.summary(idata)

File ~/micromamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/bayeux/_src/mcmc/blackjax.py:73, in _BlackjaxSampler.__call__(self, seed, **kwargs)
     71 def __call__(self, seed, **kwargs):
     72   init_key, sample_key = jax.random.split(seed)
---> 73   kwargs = self.get_kwargs(**kwargs)
     74   initial_state = self.get_initial_state(
     75       init_key, num_chains=kwargs["extra_parameters"]["num_chains"])
     77   return _sample_blackjax(
     78       initial_state=self.inverse_transform_fn(initial_state),
     79       algorithm=_ALGORITHMS[self.algorithm],
   (...)
     82       seed=sample_key,
     83       kwargs=kwargs)

File ~/micromamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/bayeux/_src/mcmc/blackjax.py:63, in _BlackjaxSampler.get_kwargs(self, **kwargs)
     61 extra_parameters = get_extra_kwargs(kwargs)
     62 constrained_log_density = self.constrained_log_density()
---> 63 adaptation_kwargs, run_kwargs = get_adaptation_kwargs(
     64     adapt_fn, algorithm, constrained_log_density, extra_parameters | kwargs)
     65 return {adapt_fn: adaptation_kwargs,
     66         "adapt.run": run_kwargs,
     67         algorithm: get_algorithm_kwargs(
     68             algorithm, constrained_log_density, kwargs),
     69         "extra_parameters": extra_parameters}

File ~/micromamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/bayeux/_src/mcmc/blackjax.py:260, in get_adaptation_kwargs(adaptation_algorithm, algorithm, log_density, kwargs)
    257   adaptation_required.remove("algorithm")
    258   adaptation_kwargs["algorithm"] = algorithm
    259   adaptation_kwargs = (
--> 260       get_algorithm_kwargs(algorithm, log_density, kwargs) | adaptation_kwargs
    261   )
    263 adaptation_required = adaptation_required - adaptation_kwargs.keys()
    265 if adaptation_required:

File ~/micromamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/bayeux/_src/mcmc/blackjax.py:310, in get_algorithm_kwargs(algorithm, log_density, kwargs)
    303 kwargs_with_defaults = {
    304     "logdensity_fn": log_density,
    305     "step_size": 0.5,
    306     "num_integration_steps": 16,
    307 } | kwargs
    308 shared.update_with_kwargs(
    309     algorithm_kwargs, reqd=algorithm_required, kwargs=kwargs_with_defaults)
--> 310 algorithm_required.remove("logdensity_fn")
    311 algorithm_required.discard("inverse_mass_matrix")
    312 algorithm_required.discard("alpha")

KeyError: 'logdensity_fn'
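
The last frame shows the mechanism: in `get_algorithm_kwargs`, Python's `set.remove` raises `KeyError` when the element is absent, whereas `set.discard` (used on the next two lines for `inverse_mass_matrix` and `alpha`) silently does nothing. The traceback does not show how `algorithm_required` is populated, so the sketch below assumes, purely for illustration, that it is the set of parameters without defaults from `inspect.signature`. `required_params`, `old_nuts`, `new_nuts` and `strip_logdensity` are hypothetical stand-ins, not bayeux or BlackJAX APIs.

```python
import inspect


def required_params(fn):
    """Names of parameters with no default (hypothetical stand-in for algorithm_required)."""
    return {
        name
        for name, param in inspect.signature(fn).parameters.items()
        if param.default is inspect.Parameter.empty
        and param.kind not in (inspect.Parameter.VAR_POSITIONAL,
                               inspect.Parameter.VAR_KEYWORD)
    }


# Hypothetical sampler signatures before and after an upstream API change.
def old_nuts(logdensity_fn, step_size, inverse_mass_matrix):
    pass


def new_nuts(step_size, inverse_mass_matrix, **kwargs):
    pass


def strip_logdensity(required, use_discard):
    """Drop 'logdensity_fn' from a copy of the required-argument set."""
    required = set(required)
    if use_discard:
        required.discard("logdensity_fn")  # no-op if the name is absent
    else:
        required.remove("logdensity_fn")  # raises KeyError if the name is absent
    return required


print(sorted(strip_logdensity(required_params(old_nuts), use_discard=False)))
# ['inverse_mass_matrix', 'step_size']

try:
    strip_logdensity(required_params(new_nuts), use_discard=False)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'logdensity_fn'

print(sorted(strip_logdensity(required_params(new_nuts), use_discard=True)))
# ['inverse_mass_matrix', 'step_size']
```

This only reproduces the failure pattern. Whether bayeux fixed it by tolerating the missing name or by adapting to a changed BlackJAX 1.2.0 signature is not stated in this thread.
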

ColCarroll commented 6 months ago

Oy, I just spotted this too. Thanks for reporting! I'll get a fix in today or tomorrow...

juanitorduz commented 6 months ago

No rush! I just wanted to help by reporting it :)

ColCarroll commented 6 months ago

This is all set. There's a bug besides this one that is causing CI to fail, but I'm hoping to have everything back in working order soon 😬

ColCarroll commented 6 months ago

I should add that I'll cut a new release once the oryx situation in #47 is settled!
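
Until that release is out, the original report suggests a stopgap: blackjax==1.1.1 works while 1.2.0 fails. A minimal sketch of an early guard, assuming you would rather get a clear message up front than the `KeyError` deep inside sampling. `parse_version` and `check_blackjax_version` are hypothetical helpers, and the 1.2.0 cutoff comes only from the two versions mentioned in this thread; the guard looks at the installed blackjax version alone and cannot tell whether your bayeux already contains the fix.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(text):
    """Turn '1.2.0' into (1, 2, 0); keeps only leading digits of each part."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # pad '1.2' to (1, 2, 0) so tuple comparison works
    return tuple(parts)


def check_blackjax_version(installed=None):
    """Hypothetical guard: flag blackjax >= 1.2.0, which failed with the bayeux of this thread."""
    if installed is None:
        try:
            installed = version("blackjax")
        except PackageNotFoundError:
            return "blackjax not installed"
    if parse_version(installed) >= (1, 2, 0):
        return (f"blackjax {installed}: pin blackjax==1.1.1 "
                "until a fixed bayeux release is available")
    return f"blackjax {installed}: ok"


print(check_blackjax_version("1.1.1"))  # blackjax 1.1.1: ok
print(check_blackjax_version())  # depends on what is installed
```
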