aesara-devs / aehmc

An HMC/NUTS implementation in Aesara
MIT License

Kernels should have `kernel(state, *parameters)` signature #40

Closed: rlouf closed this issue 3 years ago

rlouf commented 3 years ago

We currently specialize the HMC and NUTS kernels in the factory using closures. This is impractical, and we are moving away from this design in blackjax; see the related discussion.

The HMC kernel factory has the following signature:

`new_kernel(srng, logprob_fn, step_size, inverse_mass_matrix, num_integration_steps, divergence_threshold)`

The HMC kernel it returns has the signature:

`kernel(q, log_prob, log_prob_grad)`
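The cost of the current design can be illustrated with a minimal sketch in plain Python, not Aesara. It rests on several assumptions that are not part of aehmc: a `random.Random` instance stands in for `srng`, a hypothetical `logprob_and_grad` function returns the log density and its gradient (plain Python has no autodiff), and the target is 1D so the inverse mass matrix is a scalar. The point is only where `step_size` lives: it is frozen in the closure, so every new value needs a new kernel.

```python
import math
import random


def new_kernel(rng, logprob_and_grad, step_size, inverse_mass_matrix,
               num_integration_steps, divergence_threshold=1000.0):
    """Current design: step_size is frozen in the closure when the kernel is built."""

    def kernel(q, log_prob, log_prob_grad):
        # Resample the momentum p ~ N(0, M), with M = 1 / inverse_mass_matrix (1D).
        p = rng.gauss(0.0, math.sqrt(1.0 / inverse_mass_matrix))
        energy0 = -log_prob + 0.5 * inverse_mass_matrix * p * p

        # Leapfrog integration with the step_size captured from the factory.
        q_new, logp_new, grad_new = q, log_prob, log_prob_grad
        for _ in range(num_integration_steps):
            p += 0.5 * step_size * grad_new
            q_new += step_size * inverse_mass_matrix * p
            logp_new, grad_new = logprob_and_grad(q_new)
            p += 0.5 * step_size * grad_new

        # Metropolis correction; a large energy increase counts as a divergence.
        delta = energy0 - (-logp_new + 0.5 * inverse_mass_matrix * p * p)
        if -delta > divergence_threshold:
            return q, log_prob, log_prob_grad, 0.0
        accept_prob = math.exp(min(delta, 0.0))
        if rng.random() < accept_prob:
            return q_new, logp_new, grad_new, accept_prob
        return q, log_prob, log_prob_grad, accept_prob

    return kernel


def std_normal(q):
    """Hypothetical target: log density of N(0, 1) up to a constant, and its gradient."""
    return -0.5 * q * q, -q


rng = random.Random(0)
q = 1.0
log_prob, log_prob_grad = std_normal(q)

# Trying several step sizes means building one kernel per value.
for step_size in (0.1, 0.5, 1.0):
    kernel = new_kernel(rng, std_normal, step_size, 1.0, 10)
    q, log_prob, log_prob_grad, accept_prob = kernel(q, log_prob, log_prob_grad)
```

Each value of `step_size` goes through `new_kernel` again, which is the pattern that becomes awkward when an adaptation routine changes the step size at every iteration.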

We suggest moving `step_size` from the factory to the kernel instead:

`new_kernel(srng, logprob_fn, inverse_mass_matrix, num_integration_steps, divergence_threshold)`

`kernel(q, log_prob, log_prob_grad, step_size)`

I ran into this design issue while implementing step size adaptation algorithms, where we would otherwise have to build a new kernel every time the value of a parameter changes.
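A minimal sketch of the proposed signature, again in plain Python rather than Aesara and under the same assumptions: `random.Random` replaces `srng`, a hypothetical `logprob_and_grad` returns the log density and its gradient, and the target is 1D with a scalar inverse mass matrix. Because `step_size` is now an argument of `kernel`, the factory is called once and the same kernel is reused while a toy adaptation loop changes the step size on every call. The multiplicative update is a deliberately crude placeholder, not dual averaging or any other real adaptation scheme.

```python
import math
import random


def new_kernel(rng, logprob_and_grad, inverse_mass_matrix,
               num_integration_steps, divergence_threshold=1000.0):
    """Proposed design: step_size is an argument of the kernel, not of the factory."""

    def kernel(q, log_prob, log_prob_grad, step_size):
        # Resample the momentum p ~ N(0, M), with M = 1 / inverse_mass_matrix (1D).
        p = rng.gauss(0.0, math.sqrt(1.0 / inverse_mass_matrix))
        energy0 = -log_prob + 0.5 * inverse_mass_matrix * p * p

        # Leapfrog integration with the step_size passed at call time.
        q_new, logp_new, grad_new = q, log_prob, log_prob_grad
        for _ in range(num_integration_steps):
            p += 0.5 * step_size * grad_new
            q_new += step_size * inverse_mass_matrix * p
            logp_new, grad_new = logprob_and_grad(q_new)
            p += 0.5 * step_size * grad_new

        # Metropolis correction; a large energy increase counts as a divergence.
        delta = energy0 - (-logp_new + 0.5 * inverse_mass_matrix * p * p)
        if -delta > divergence_threshold:
            return q, log_prob, log_prob_grad, 0.0
        accept_prob = math.exp(min(delta, 0.0))
        if rng.random() < accept_prob:
            return q_new, logp_new, grad_new, accept_prob
        return q, log_prob, log_prob_grad, accept_prob

    return kernel


def std_normal(q):
    """Hypothetical target: log density of N(0, 1) up to a constant, and its gradient."""
    return -0.5 * q * q, -q


rng = random.Random(0)
# The factory is called once, outside the adaptation loop.
kernel = new_kernel(rng, std_normal, inverse_mass_matrix=1.0, num_integration_steps=10)

q = 1.0
log_prob, log_prob_grad = std_normal(q)
step_size, target_accept = 1.0, 0.8

# Toy adaptation: the same kernel is called with a different step_size each time.
for _ in range(200):
    q, log_prob, log_prob_grad, accept_prob = kernel(q, log_prob, log_prob_grad, step_size)
    # Crude placeholder update (NOT dual averaging): grow if accepting often, else shrink.
    step_size *= 1.1 if accept_prob > target_accept else 0.9
```

Only the kernel's arguments change inside the loop, so adapting a parameter no longer requires rebuilding the kernel.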

I think this issue should be addressed before we move forward with the adaptation.