rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0
325 stars, 17 forks

Replace `jax.numpy` abbreviation from `np` to `jnp` #72

Closed sidravi1 closed 3 years ago

sidravi1 commented 3 years ago

JAX and NumPyro both use `jnp` for `jax.numpy` and `np` for `numpy`. To avoid confusion, we should switch to their convention.

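We need to go through the entire codebase and change these import lines:

`import jax.numpy as np` -> `import jax.numpy as jnp`

`import numpy as onp` -> `import numpy as np`

Then replace every reference to `np` with `jnp` and every reference to `onp` with `np`, matching whole names and in that order so the two renames don't collide.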
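For anyone picking this up, here is a rough sketch of how the rename could be scripted. It is not what the fix in the codebase necessarily looks like: `rename_aliases` and `rewrite_tree` are hypothetical helper names, and the regex assumes the aliases only ever appear as bare names (it will also touch matching words inside strings and comments, so the diff still needs a manual review). Both renames happen in one pass, so an `onp` that becomes `np` is never rewritten a second time to `jnp`.

```python
import re
from pathlib import Path

# Hypothetical helpers for illustration only; not part of mcx.
# Old alias -> new alias, following the JAX / NumPyro convention.
_ALIASES = {"np": "jnp", "onp": "np"}

# Match `np` or `onp` only as a standalone name: not preceded by a word
# character or a dot (so `jnp`, `onp`'s inner `np`, and `obj.np` are skipped)
# and not followed by a word character (so `np_array` is skipped).
_PATTERN = re.compile(r"(?<![\w.])(onp|np)\b")


def rename_aliases(source: str) -> str:
    """Rename `np` -> `jnp` and `onp` -> `np` in a single pass."""
    return _PATTERN.sub(lambda m: _ALIASES[m.group(1)], source)


def rewrite_tree(root: str) -> None:
    """Apply the rename to every .py file under `root`, in place."""
    for path in Path(root).rglob("*.py"):
        original = path.read_text()
        updated = rename_aliases(original)
        if updated != original:
            path.write_text(updated)
```

Running `rewrite_tree` on a clean checkout and then reading through `git diff` before committing keeps the change easy to audit.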

kancurochat commented 3 years ago

Hi! I'm going to work on this. Could you please assign this issue to me? Thanks in advance :)

rlouf commented 3 years ago

Fixed in #73, that was fast!