Farama-Foundation / Jumpy

On-the-fly conversions between Jax and NumPy tensors
Apache License 2.0

is_jitted does not properly detect if inside a jit call (tested in version 0.4.12) #33

Open · bheijden opened this issue 1 year ago

bheijden commented 1 year ago

Since jax 0.4.12 (and possibly other versions > 0.4.1), the is_jitted function no longer detects correctly whether it is being called inside a jit call.

```python
import jax
import jumpy


def f(a):
    if jumpy.core.is_jitted():
        print("JITTED")
    else:
        print("NOT JITTED")
    return a * 2


jax.jit(f)(1)  # prints "NOT JITTED", although "JITTED" is expected
```

Changing the is_jitted function to the following resolves the problem, but it is a very hacky fix because it relies on JAX's private jax._src internals. I am also not sure about backward compatibility.


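```python
import jax
import jumpy as jp


def is_jitted() -> bool:
    """Returns true if currently inside a jax.jit call or jit is disabled."""
    if not jp.is_jax_installed:
        return False
    elif jax.config.jax_disable_jit:
        return True
    else:
        # Hacky: reads JAX's private trace stack. The base trace is always on
        # the stack, so more than one entry means something is being traced.
        return len(jax._src.core.thread_local_state.trace_state.trace_stack.stack) > 1
        # Previous check, kept for reference:
        # return jax.core.cur_sublevel().level > 0
```
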
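An alternative sketch that avoids private JAX internals: instead of asking whether any trace is active, check whether a value passed into the function is a tracer. This assumes the function receives an argument that can be inspected; is_traced is a hypothetical helper, not part of jumpy, and it relies on jax.core.Tracer, the base class of the abstract values JAX substitutes for arguments while tracing.

```python
import jax


def is_traced(x) -> bool:
    """Hypothetical helper: True if x is a JAX tracer (e.g. inside jax.jit).

    Inspects a value instead of JAX's private trace stack, so it only
    works when the function has an argument that flows through it.
    """
    return isinstance(x, jax.core.Tracer)


calls = []


def f(a):
    # Python side effects run at trace time, so this records what f saw.
    calls.append(is_traced(a))
    return a * 2


f(1)           # eager call: `a` is the plain int 1
jax.jit(f)(1)  # traced call: `a` is a Tracer
print(calls)   # [False, True]
```

Two caveats compared with is_jitted above: under jax.disable_jit() the argument stays concrete, so is_traced returns False where is_jitted deliberately returns True, and is_traced also returns True under other transformations such as jax.vmap or jax.grad, since their arguments are tracers too.
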
pseudo-rnd-thoughts commented 1 year ago

I was planning to update the whole project to jax 0.4.X, as 0.4 replaces DeviceArray with Array. It would therefore be good to include this fix in the project when we update.
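For the DeviceArray removal mentioned above, a type check written against jax.Array is one way to tell JAX arrays apart from NumPy arrays on jax 0.4 and later. is_jax_array is a hypothetical helper for illustration, not an existing jumpy function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_jax_array(x) -> bool:
    """Hypothetical helper: True for JAX arrays, False for NumPy arrays."""
    # From jax 0.4 onward, jax.Array is the unified array type, so checks
    # that used to target DeviceArray can target jax.Array instead.
    return isinstance(x, jax.Array)


print(is_jax_array(jnp.ones(3)))  # True
print(is_jax_array(np.ones(3)))   # False
```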