theorashid closed this pull request 9 months ago.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
Sorry, there's no great way set up for me to leave comments directly on the PR, but this looks great. A few quick changes:
Can you use the more standard seed splitting (and new-style keys!):
seed = jax.random.key(0)
init_key, sample_key, test_init_key, inference_key = jax.random.split(seed, 4)
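For anyone unfamiliar with the new-style keys requested here, a minimal sketch, assuming JAX 0.4.16 or later, where jax.random.key returns a typed key array (a scalar with a dedicated key dtype) rather than the legacy uint32[2] array from jax.random.PRNGKey. One split yields four independent keys, one per stage; the toy draws x and y are illustrative only.

```python
import jax
import jax.numpy as jnp

# New-style typed key: a scalar array whose dtype is a PRNG key dtype.
seed = jax.random.key(0)

# Legacy raw key, for comparison: a uint32 array of shape (2,).
legacy_seed = jax.random.PRNGKey(0)

# One split into four independent keys, one per stage of the workflow.
init_key, sample_key, test_init_key, inference_key = jax.random.split(seed, 4)

# Each stage consumes its own key; reusing a single key for two stages
# would make their randomness identical.
x = jax.random.normal(sample_key, (3,))
y = jax.random.normal(inference_key, (3,))
```

Keeping all the splitting in one place like this makes it easy to see that no key is consumed twice.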
I'd probably leave out the advertisement for .get_kwargs, unless you feel strongly!
I think you need to call
constrained_samples = from_unconstrained(samples, param_props)
to get your draws back in the original (constrained) space. I'm not sure why that doesn't happen in the dynamax notebook, which makes me nervous that I'm wrong about it.
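To illustrate why the draws need mapping back: HMC samples in an unconstrained space, so a positive parameter such as a variance is represented there by an unbounded real number and must be transformed back before it is used or reported. The sketch below is a hypothetical stand-in for dynamax's from_unconstrained, not its actual implementation: it assumes a softplus transform for one positive parameter and an identity transform for an unconstrained one, whereas the real function reads its transforms from param_props.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-parameter transforms (NOT dynamax's API): each maps an
# unconstrained real value onto the parameter's constrained support.
TRANSFORMS = {
    "mean": lambda u: u,          # already unconstrained
    "variance": jax.nn.softplus,  # R -> (0, inf)
}

INVERSE_TRANSFORMS = {
    "mean": lambda c: c,
    # Inverse softplus log(exp(c) - 1), written stably as c + log(1 - exp(-c)).
    "variance": lambda c: c + jnp.log(-jnp.expm1(-c)),
}


def to_unconstrained_sketch(params):
    """Map constrained parameters into the space the sampler explores."""
    return {k: INVERSE_TRANSFORMS[k](v) for k, v in params.items()}


def from_unconstrained_sketch(params):
    """Map unconstrained draws back to the original (constrained) space."""
    return {k: TRANSFORMS[k](v) for k, v in params.items()}


# Pretend these are 1000 sampler draws in unconstrained space.
mean_key, var_key = jax.random.split(jax.random.key(0))
samples = {
    "mean": jax.random.normal(mean_key, (1000,)),
    "variance": jax.random.normal(var_key, (1000,)) - 2.0,  # mostly negative
}

constrained_samples = from_unconstrained_sketch(samples)
```

The raw "variance" draws are mostly negative, which is meaningless for a variance; only after the transform are they all positive.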
You're spot on. They call from_unconstrained() inside hmc_step in their notebook:
@jit
def hmc_step(hmc_state, step_key):
next_hmc_state, _ = hmc_kernel(step_key, hmc_state)
params = from_unconstrained(hmc_state.position, props)
return next_hmc_state, params
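A minimal, self-contained sketch of how a step function shaped like the one above is typically driven with jax.lax.scan so that the constrained draws are stacked into one array. To avoid depending on blackjax or dynamax, it swaps in a plain random-walk Metropolis kernel for hmc_kernel (a different, simpler MCMC algorithm with the same (key, state) -> (state, info) shape) and a softplus-only from_unconstrained. The target density, SamplerState, and rwm_kernel are all hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import jit


class SamplerState(NamedTuple):
    # Mirrors the `.position` field used in the notebook snippet.
    position: jax.Array
    logdensity: jax.Array


def logdensity_fn(u):
    # Hypothetical target: sigma = softplus(u) ~ Exponential(1), written in
    # unconstrained space as log p(sigma) + log |d sigma / du|.
    return -jax.nn.softplus(u) + jax.nn.log_sigmoid(u)


def from_unconstrained(u, props=None):
    # Hypothetical stand-in for dynamax's from_unconstrained: softplus only.
    return jax.nn.softplus(u)


def rwm_kernel(key, state, step_size=1.0):
    # Random-walk Metropolis, used ONLY as a stand-in for an HMC kernel.
    prop_key, accept_key = jax.random.split(key)
    proposal = state.position + step_size * jax.random.normal(prop_key)
    prop_logdensity = logdensity_fn(proposal)
    log_ratio = prop_logdensity - state.logdensity
    accept = jnp.log(jax.random.uniform(accept_key)) < log_ratio
    new_state = SamplerState(
        jnp.where(accept, proposal, state.position),
        jnp.where(accept, prop_logdensity, state.logdensity),
    )
    return new_state, accept


@jit
def step(state, step_key):
    next_state, _ = rwm_kernel(step_key, state)
    # As in the notebook: map the *current* position back before returning it,
    # so scan stacks constrained draws starting from the initial state.
    params = from_unconstrained(state.position, None)
    return next_state, params


init_position = jnp.array(0.0, dtype=jnp.float32)
init_state = SamplerState(init_position, logdensity_fn(init_position))
step_keys = jax.random.split(jax.random.key(1), 5000)
final_state, constrained_draws = jax.lax.scan(step, init_state, step_keys)
```

Because the quoted hmc_step maps hmc_state.position (the state before the transition), the first collected draw is the initial position; mapping next_hmc_state.position instead would drop the initial point and include the final one.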
Pushing the other changes now
Really basic example just to swap out the inference stage. Tried to keep it as short as possible.