jax-ml / coix

Inference Combinators in JAX
https://coix.readthedocs.io/en/latest/
Apache License 2.0
43 stars, 2 forks

Enable custom prng #15

Closed fehiepsi closed 1 year ago

fehiepsi commented 1 year ago

This PR adds support for new-style typed PRNG keys: https://github.com/google/jax/blob/main/docs/jep/9263-typed-keys.md

The main differences are that a key now has its own dtype and a scalar shape.
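
For illustration, here is a minimal sketch of the two key styles in plain JAX. It is not taken from this PR's diff. It assumes a JAX version that provides `jax.random.key` and the default threefry PRNG implementation; the variable names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Legacy style: a raw uint32 array whose trailing dimension holds the key data.
old_key = jax.random.PRNGKey(0)
old_shape = old_key.shape    # (2,) with the default threefry implementation
old_dtype = old_key.dtype    # uint32

# New style: a typed key array with scalar shape.
new_key = jax.random.key(0)
new_shape = new_key.shape    # ()
is_typed = bool(jnp.issubdtype(new_key.dtype, jax.dtypes.prng_key))

# Splitting shows the shape difference: no trailing data dimension on typed keys.
old_split_shape = jax.random.split(old_key, 4).shape   # (4, 2)
new_split_shape = jax.random.split(new_key, 4).shape   # (4,)

# Samplers accept either style.
x_old = jax.random.normal(old_key, (3,))
x_new = jax.random.normal(new_key, (3,))

# Converting between the typed key and its raw uint32 data.
raw = jax.random.key_data(new_key)
wrapped = jax.random.wrap_key_data(raw)
```

Code that indexes or reshapes keys based on a trailing dimension of size 2 is what typically needs updating to accept the scalar-shaped typed keys.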