rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

switch `import jax.numpy as np` to `import jax.numpy as jnp` #55

Closed jeremiecoullon closed 3 years ago

jeremiecoullon commented 3 years ago

This is to stay aligned with JAX core and its documentation. From this issue: there has been a bit of a transition in how JAX-flavored NumPy is used, from "JAX is a drop-in replacement for NumPy" to "JAX is a tool to use beside NumPy".
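A minimal sketch of what "a tool to use beside NumPy" looks like in practice, assuming a recent JAX release: both modules are imported side by side, `np` for host-side NumPy and `jnp` for code that JAX traces, differentiates and compiles. The function `log_density` is a hypothetical example, not code from mcx.

```python
import numpy as np        # regular NumPy, runs on the host
import jax
import jax.numpy as jnp   # JAX's NumPy-like API, traceable and runs on CPU/GPU

def log_density(x):
    # Hypothetical example: standard normal log-density written with jnp
    # so that JAX can trace, differentiate and JIT-compile it.
    return -0.5 * jnp.sum(x ** 2) - 0.5 * x.size * jnp.log(2 * jnp.pi)

# Gradient of the log-density, compiled with jit.
grad_log_density = jax.jit(jax.grad(log_density))

x_host = np.array([0.5, -1.0, 2.0])   # plain NumPy array built on the host
g = grad_log_density(x_host)          # JAX converts the input to a JAX array
g_host = np.asarray(g)                # back to NumPy for host-side work

# JAX arrays are immutable: `y[0] = 1.0` would raise an error,
# so updates use the functional `.at[...]` syntax and return a new array.
y = jnp.zeros(3)
y = y.at[0].set(1.0)

print(g_host, y)
```

Keeping a distinct alias makes it obvious at each call site whether an operation runs in plain NumPy or is traced by JAX, which matters because the two differ in places such as in-place assignment.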

rlouf commented 3 years ago

I assume you meant jnp. That's a good point, and it's better to do it before it becomes too complicated.

jeremiecoullon commented 3 years ago

Yes! Just changed it

rlouf commented 3 years ago

I'm waiting to see whether the collaboration with PyMC3 and NumPyro happens before making the changes here.

rlouf commented 3 years ago

The samplers are being moved to BlackJAX, where we will use `jnp` instead of `np`. Closing.