jax-ml / bayeux

State-of-the-art inference for your Bayesian models.
https://jax-ml.github.io/bayeux/
Apache License 2.0
177 stars, 8 forks

Add example for Dynamax #28

Closed: theorashid closed this 9 months ago

theorashid commented 9 months ago

Really basic example that just swaps bayeux in for the inference stage of a Dynamax model. I tried to keep it as short as possible.

google-cla[bot] commented 9 months ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up-to-date status, view the checks section at the bottom of the pull request.

ColCarroll commented 9 months ago

Sorry, there's no great way set up for me to leave comments directly on the PR, but this looks great. A few quick changes:

  1. Can you do the more standard seed splitting (and use new-style keys!):

    import jax
    # New-style typed key, split once into independent keys for each stage
    seed = jax.random.key(0)
    init_key, sample_key, test_init_key, inference_key = jax.random.split(seed, 4)
  2. I'd probably leave out the advertisement for .get_kwargs, unless you feel strongly!

  3. I think you need to call

    from dynamax.parameters import from_unconstrained
    constrained_samples = from_unconstrained(samples, param_props)

    to get your draws back in the original (constrained) space. I'm not sure why that doesn't happen in the Dynamax notebook, which makes me nervous that I'm wrong about it.
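To illustrate point 3, here is a minimal JAX-only sketch of what mapping draws back to the constrained space does. It assumes a toy model with an unconstrained "mean" and a positive "scale" parameter, and uses a hypothetical constrain function (identity and softplus, one common choice) as a stand-in for from_unconstrained(samples, param_props); it is not Dynamax's implementation. The fake draws are assumed to have leading (chain, draw) axes, as in an ArviZ posterior.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for from_unconstrained(samples, param_props): one
# constraining map per parameter (identity for "mean", softplus for "scale").
CONSTRAINERS = {"mean": lambda x: x, "scale": jax.nn.softplus}


def constrain(unconstrained):
    """Map each unconstrained parameter array back to its original space."""
    return {name: CONSTRAINERS[name](value) for name, value in unconstrained.items()}


# Fake sampler output with leading (chain, draw) axes: 4 chains x 100 draws.
mean_key, scale_key = jax.random.split(jax.random.key(0))
unconstrained_draws = {
    "mean": jax.random.normal(mean_key, (4, 100)),
    "scale": jax.random.normal(scale_key, (4, 100)),
}

# Softplus is elementwise, so the extra (chain, draw) axes need no special handling.
constrained_draws = constrain(unconstrained_draws)
```

The raw "scale" draws live on the whole real line and include negative values; only after the transform do they respect the positivity constraint, which is why skipping this step leaves the draws in the wrong space. For non-elementwise constraints (for example, matrix-valued parameters) the transform may not broadcast over the leading axes, in which case wrapping it in jax.vmap is one option.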

theorashid commented 9 months ago

You're spot on. They call from_unconstrained() inside hmc_step in their notebook:

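    from jax import jit
    from dynamax.parameters import from_unconstrained

    # hmc_kernel and props come from elsewhere in the notebook (not shown)
    @jit
    def hmc_step(hmc_state, step_key):
        next_hmc_state, _ = hmc_kernel(step_key, hmc_state)
        # Map the current (unconstrained) position back to the constrained space
        params = from_unconstrained(hmc_state.position, props)
        return next_hmc_state, params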
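For reference, a minimal sketch of the same pattern that runs with JAX alone. It assumes the step function is driven by jax.lax.scan (suggested by its (state, key) -> (state, output) signature), swaps in a toy random-walk Metropolis kernel for the notebook's hmc_kernel, and uses a hypothetical softplus-based constrain function in place of from_unconstrained(..., props). The names log_density, rw_kernel, constrain, and step are illustrative only. The step keys are new-style keys split from a single sample_key, as suggested in the review.

```python
import jax
import jax.numpy as jnp


def log_density(position):
    # Hypothetical target on the unconstrained space: independent N(0, 1) parameters.
    return -0.5 * (position["mean"] ** 2 + position["scale"] ** 2)


def rw_kernel(step_key, position):
    """Toy random-walk Metropolis step, a stand-in for the notebook's hmc_kernel."""
    mean_key, scale_key, accept_key = jax.random.split(step_key, 3)
    proposal = {
        "mean": position["mean"] + 0.5 * jax.random.normal(mean_key),
        "scale": position["scale"] + 0.5 * jax.random.normal(scale_key),
    }
    log_ratio = log_density(proposal) - log_density(position)
    accept = jnp.log(jax.random.uniform(accept_key)) < log_ratio
    next_position = jax.tree_util.tree_map(
        lambda new, old: jnp.where(accept, new, old), proposal, position
    )
    return next_position, accept


def constrain(position):
    # Hypothetical stand-in for from_unconstrained(position, props).
    return {"mean": position["mean"], "scale": jax.nn.softplus(position["scale"])}


@jax.jit
def step(position, step_key):
    next_position, _ = rw_kernel(step_key, position)
    # As in hmc_step, record the *current* position mapped to the constrained space.
    return next_position, constrain(position)


init_position = {"mean": jnp.array(0.0), "scale": jnp.array(0.0)}
sample_key = jax.random.key(0)
step_keys = jax.random.split(sample_key, 500)
final_position, constrained_samples = jax.lax.scan(step, init_position, step_keys)
```

Like hmc_step, step returns the constrained version of the state it was given rather than of next_position, so the first recorded sample is the initial position (here mean 0 and scale softplus(0) = log 2).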

Pushing the other changes now.