blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

New jax keys #703

Closed. reubenharry closed this issue 3 months ago.

reubenharry commented 3 months ago

Current behavior

We currently use jax.random.PRNGKey throughout.

Desired behavior

Following the JAX design note on typed keys (JEP 9263, https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html), I'm wondering whether we want to switch to the new key style, jax.random.key.
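For context, here is a minimal sketch of the difference between the two key styles, assuming a recent JAX release (0.4.16 or later, where `jax.dtypes.prng_key` exists) and the default threefry PRNG implementation. The legacy key is a raw `uint32` array, while the new key is a scalar array with a dedicated key dtype; `jax.random.key_data` and `jax.random.wrap_key_data` convert between the two representations.

```python
import jax
import jax.numpy as jnp

# Legacy style: a raw uint32 array of shape (2,) holding the threefry state.
old_key = jax.random.PRNGKey(0)

# New style (JEP 9263): a scalar (shape ()) array with a PRNG key dtype.
new_key = jax.random.key(0)

# Typed keys can be recognised by their dtype; legacy keys cannot.
is_typed = jnp.issubdtype(new_key.dtype, jax.dtypes.prng_key)

# Convert a typed key to its raw uint32 data and back again when needed.
raw = jax.random.key_data(new_key)          # uint32 array, shape (2,)
rewrapped = jax.random.wrap_key_data(raw)   # typed key again

# Splitting a typed key yields an array of typed keys with shape (2,),
# instead of the (2, 2) uint32 array produced by splitting a legacy key.
new_children = jax.random.split(new_key)
old_children = jax.random.split(old_key)

# With the same seed and the default threefry implementation,
# both styles produce the same random draws.
same = bool(jnp.allclose(jax.random.normal(old_key, (3,)),
                         jax.random.normal(new_key, (3,))))
```

Because sampling functions in `jax.random` accept either style, a library can migrate incrementally: the main visible change is the shape and dtype of the key arrays passed around.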

junpenglao commented 3 months ago

We are already using the new random key: https://github.com/blackjax-devs/blackjax/pull/569. It's just that we still call the custom type PRNGKey.
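To illustrate why keeping the name is harmless, here is a minimal sketch assuming the custom type is a plain alias of `jax.Array`. The alias and the `draw_step` function below are hypothetical illustrations, not BlackJAX's actual definitions (see `blackjax/types.py` in the repository for those). Since the annotation covers any `jax.Array`, the same function accepts both a new typed key and a legacy `uint32` key.

```python
import jax
import jax.numpy as jnp

# Hypothetical alias: the old name is kept for annotations, but it simply
# means "a JAX array", which covers both legacy and typed keys.
PRNGKey = jax.Array


def draw_step(rng_key: PRNGKey, position: jax.Array) -> jax.Array:
    """Hypothetical random-walk step that works with either key style."""
    noise = jax.random.normal(rng_key, position.shape)
    return position + 0.1 * noise


position = jnp.zeros(3)
typed = draw_step(jax.random.key(1), position)        # new typed key
legacy = draw_step(jax.random.PRNGKey(1), position)   # legacy uint32 key
```

Under the default threefry implementation both calls draw identical noise for the same seed, so the name `PRNGKey` is only a label on the annotation, not a commitment to the legacy key format.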