Closed: clement-bonnet closed 7 months ago
- `jax.random.KeyArray` -> `jax.random.PRNGKey`
- `jnp.isclose` instead of `==` for assertions with `float32`
- `frac` variable converted to `int32`, because NumPy uses `float64` by default and JAX does not.
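
The dtype mismatch behind these changes can be reproduced with a minimal sketch. It assumes JAX's default configuration (x64 mode disabled); the value `0.1` and the stand-in for `frac` are hypothetical and are not taken from the changed code.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy defaults to float64; JAX defaults to float32 unless x64 mode is enabled.
np_value = np.array(0.1)
jax_value = jnp.array(0.1, dtype=jnp.float32)
print(np_value.dtype, jnp.array(0.1).dtype)

# Exact equality between the float32 value and the float64 reference fails,
# because 0.1 rounds to a different number at each precision.
exact_match = float(jax_value) == float(np_value)

# jnp.isclose compares within a tolerance, so the assertion stays robust.
close_match = bool(jnp.isclose(jax_value, np_value))

# Hypothetical stand-in for the `frac` variable: cast explicitly to int32
# rather than relying on either library's default dtype.
frac = jnp.array(3.0).astype(jnp.int32)

# Keys are created with jax.random.PRNGKey; samples are float32 by default.
key = jax.random.PRNGKey(0)
sample = jax.random.uniform(key, (3,))
print(exact_match, close_match, frac.dtype, sample.dtype)
```

Fixing the dtype explicitly on the JAX side keeps the comparison meaningful even if x64 mode is switched on elsewhere in a test suite.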