I get different results from `jp.random_split` after the function is jit-ted. This behavior is caused by `jp.random_prngkey(0)`, which always returns a NumPy array. Therefore, in the non-jit version `jp_prng` uses its NumPy implementation, while under jit it uses the JAX one, and the two implementations produce different keys.
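A minimal sketch of the mismatch, using plain NumPy and JAX rather than jumpy. It assumes that jumpy's NumPy branch of `random_split` splits a key by reseeding `numpy.random.default_rng` with it, the same way the `random_prngkey` snippet in this issue seeds its NumPy branch; that detail of jumpy's internals is an assumption. The jit path is real JAX behavior: `jax.jit` traces the NumPy argument as a JAX array, so `jax.random.split` runs on it.

```python
import jax
import numpy as onp

# Key as produced by the NumPy branch of random_prngkey(0) (see the snippet in this issue).
key = onp.random.default_rng(0).integers(low=0, high=2**32, dtype="uint32", size=2)

# Non-jit path (assumed): the NumPy key is split by reseeding NumPy's generator with it.
eager_keys = onp.random.default_rng(key).integers(
    low=0, high=2**32, dtype="uint32", size=(2, 2))

# Jit path: jax.jit traces the NumPy argument as a JAX array,
# so JAX's own split is applied to the same key.
jit_keys = jax.jit(lambda k: jax.random.split(k, num=2))(key)

print(eager_keys)
print(jit_keys)
print(onp.array_equal(eager_keys, onp.asarray(jit_keys)))  # False
```

Both paths start from the identical key, but they run unrelated generators, so the split keys disagree.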
A possible solution is to have `random_prngkey` select the backend based on the seed's type (a Python `int`, a JAX `int32`, or a NumPy `int32`). The jit and non-jit versions would then produce identical keys whenever the seed is provided as a JAX type.
The `random_prngkey` function would need to be changed as follows:
```python
import jax
import jax.numpy as jnp
import numpy as onp
from brax import jumpy as jp  # assumes Brax's jumpy module

def random_prngkey(seed: jp.int32) -> jp.ndarray:
  """Returns a PRNG key given a seed."""
  if jp._which_np(seed) is jnp:  # NOTE: selects backend based on seed type.
    return jax.random.PRNGKey(seed)
  else:
    rng = onp.random.default_rng(seed)
    return rng.integers(low=0, high=2**32, dtype="uint32", size=2)
```
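A self-contained sketch of the proposed behavior that runs without Brax. `which_np`, `random_prngkey`, and `random_split` below are hypothetical stand-ins for jumpy's `_which_np`, `random_prngkey`, and `random_split`; the NumPy branch of `random_split` is an assumption about jumpy's internals. The key is created outside the jitted function and then split both eagerly and under `jax.jit`, using a JAX-typed seed as the proposal requires.

```python
import jax
import jax.numpy as jnp
import numpy as onp


def which_np(x):
  """Hypothetical stand-in for jumpy's _which_np."""
  # Python ints and NumPy values use the NumPy backend; JAX arrays and
  # jit tracers use the JAX backend.
  if isinstance(x, (int, onp.integer, onp.ndarray)):
    return onp
  return jnp


def random_prngkey(seed):
  """Proposed behavior: pick the backend from the seed's type."""
  if which_np(seed) is jnp:
    return jax.random.PRNGKey(seed)
  rng = onp.random.default_rng(seed)
  return rng.integers(low=0, high=2**32, dtype="uint32", size=2)


def random_split(key, num=2):
  """Hypothetical split that dispatches on the key's type."""
  if which_np(key) is jnp:
    return jax.random.split(key, num=num)
  # Assumed NumPy branch: reseed NumPy's generator with the key.
  rng = onp.random.default_rng(key)
  return rng.integers(low=0, high=2**32, dtype="uint32", size=(num, 2))


# A JAX-typed seed yields a JAX key, so both runs use jax.random.split.
key = random_prngkey(jnp.int32(0))
eager_keys = random_split(key)
jit_keys = jax.jit(random_split)(key)

print(onp.array_equal(onp.asarray(eager_keys), onp.asarray(jit_keys)))  # True
```

With a plain Python `int` seed the key is still a NumPy array, so the jit/non-jit mismatch remains in that case; the change only helps callers who pass a JAX-typed seed, as the proposal states.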