Farama-Foundation / Jumpy

On-the-fly conversions between Jax and NumPy tensors
Apache License 2.0

Non-deterministic behavior jit vs non-jit #29

Closed bheijden closed 1 year ago

bheijden commented 1 year ago

I get different results from the jp.random_split function after the function is jit-ted.

import jumpy as jp
import jax
import numpy as onp

def jp_prng(rng):
    return jp.random_split(rng)   # jit --> uses jax, non-jit --> uses numpy.

def jax_prng(rng):
    return jax.random.split(rng)

jp_prng_jit = jax.jit(jp_prng)
jax_prng_jit = jax.jit(jax_prng)

seed = jp.random_prngkey(0)  # always returns a numpy array (regardless of input type).

jax_key_jit = jax_prng_jit(seed)
jax_key = jax_prng(seed)

jp_key_jit = jp_prng_jit(seed)
jp_key = jp_prng(seed)

print(f"jax: {onp.isclose(jax_key_jit, jax_key).all()}")   # --> jax: True
print(f"jumpy: {onp.isclose(jp_key, jp_key_jit).all()}")  # --> jumpy: False

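This happens because jp.random_prngkey(0) always returns a NumPy array. Called eagerly, jp_prng receives that NumPy array, so jp.random_split takes its NumPy path. Under jax.jit the same argument is traced as a JAX value, so jp.random_split takes its JAX path. The two implementations generate keys differently, which is why the results do not match.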
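
The change in argument type can be observed without Jumpy. The following is a minimal sketch, assuming only that jax.jit replaces array arguments with tracers while it traces the function; the helper record_kind is hypothetical and is not part of Jumpy or JAX.

```python
import jax
import numpy as onp

kinds = []  # records what the function body saw on each call


def record_kind(x):
    # Inside jax.jit the argument is a Tracer, even if the caller passed NumPy.
    kinds.append("tracer" if isinstance(x, jax.core.Tracer) else type(x).__name__)
    return x


key = onp.zeros(2, dtype=onp.uint32)  # stand-in for a NumPy-backed key

record_kind(key)           # eager call: the body sees the NumPy array
jax.jit(record_kind)(key)  # jitted call: the body sees a JAX tracer

print(kinds)
```

Any function that chooses its backend from the argument type would therefore pick NumPy in the first call and JAX in the second, even though the caller passed the same object both times.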

A possible solution is to have random_prngkey select the backend based on the seed type (int, jnp.int32, np.int32). A seed with a JAX type would then yield a JAX key, so both the jit and non-jit calls would use the JAX implementation and return the same keys.

The random_prngkey function would need to be changed as follows:


import jax
import jax.numpy as jnp
import numpy as onp
import jumpy as jp

def random_prngkey(seed: jp.int32) -> jp.ndarray:
    """Returns a PRNG key given a seed."""
    if jp._which_np(seed) is jnp:   # NOTE: selects backend based on seed type.
        return jax.random.PRNGKey(seed)
    else:
        rng = onp.random.default_rng(seed)
        return rng.integers(low=0, high=2**32, dtype="uint32", size=2)
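
To check that the proposal has the intended effect, the sketch below pairs it with a stand-in split function and compares jitted and eager results. This is not Jumpy's code: which_np, random_prngkey and random_split are hypothetical re-implementations written for this example, and the NumPy branch of random_split is only an assumption about what a NumPy fallback might look like.

```python
import jax
import jax.numpy as jnp
import numpy as onp


def which_np(*args):
    # Hypothetical stand-in for jp._which_np: pick JAX if any argument is a JAX value.
    return jnp if any(isinstance(a, jax.Array) for a in args) else onp


def random_prngkey(seed):
    # Proposed behaviour: the backend follows the seed type.
    if which_np(seed) is jnp:
        return jax.random.PRNGKey(seed)
    rng = onp.random.default_rng(seed)
    return rng.integers(low=0, high=2**32, dtype="uint32", size=2)


def random_split(rng):
    # Assumed fallback: JAX keys use jax.random.split, NumPy keys reseed a Generator.
    if which_np(rng) is jnp:
        return jax.random.split(rng)
    gen = onp.random.default_rng(rng)
    return gen.integers(low=0, high=2**32, dtype="uint32", size=(2, 2))


jax_key = random_prngkey(jnp.int32(0))  # JAX-typed seed gives a JAX key
np_key = random_prngkey(0)              # Python int gives a NumPy key

# With a JAX key, the eager and jitted calls both take the JAX branch.
same_with_jax_key = bool(
    onp.array_equal(random_split(jax_key), jax.jit(random_split)(jax_key))
)
print(type(np_key).__name__, same_with_jax_key)
```

A Python int seed still produces a NumPy key under this sketch, so the jit/non-jit mismatch would remain for callers who do not pass a JAX-typed seed.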