Status: Open. bheijden opened this issue 1 year ago.
Since JAX 0.4.12 (possibly also other versions > 0.4.1), the `is_jitted` function does not correctly detect whether it is being called inside a `jax.jit` call:
```python
import jax
import jumpy


def f(a):
    if jumpy.core.is_jitted():
        print("JITTED")
    else:
        print("NOT JITTED")
    return a * 2


jax.jit(f)(1)  # prints "NOT JITTED", although f is being traced by jax.jit
```
Changing the `is_jitted` function to the following resolves the problem, but it is a very hacky fix because it reads private `jax._src` internals, and I am not sure about backward compatibility:
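```python
import jax
import jumpy as jp  # `jp` refers to the jumpy package


def is_jitted() -> bool:
    """Returns true if currently inside a jax.jit call or jit is disabled."""
    if jp.is_jax_installed is False:
        return False
    elif jax.config.jax_disable_jit:
        return True
    else:
        # Private JAX internals: more than one entry on the trace stack
        # indicates that a transformation such as jit is currently tracing.
        return len(jax._src.core.thread_local_state.trace_state.trace_stack.stack) > 1
        # Previous implementation:
        # return jax.core.cur_sublevel().level > 0
```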
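A less version-dependent option, which is a different technique from the trace-stack check and not something jumpy currently provides, is to inspect an argument instead of JAX's internal trace state: inside `jax.jit`, array arguments arrive as `jax.core.Tracer` instances. The sketch below assumes the function has such an argument to inspect; the helper name `is_traced` is hypothetical. Note that it also returns `True` under other transformations (for example `jax.vmap` or `jax.grad`), returns `False` for arguments marked static, and returns `False` when jit is disabled, so it does not reproduce the `jax_disable_jit` branch of `is_jitted`.

```python
import jax


def is_traced(x) -> bool:
    """Hypothetical helper: True if `x` is a JAX tracer, i.e. a transformation such as jit is tracing."""
    return isinstance(x, jax.core.Tracer)


def f(a):
    if is_traced(a):
        print("JITTED")
    else:
        print("NOT JITTED")
    return a * 2


jax.jit(f)(1)  # prints "JITTED" while jax.jit traces f
f(1)           # prints "NOT JITTED" for a plain Python int
```

Because the check only uses the public `jax.core.Tracer` type, it avoids depending on how the trace stack is stored internally, at the cost of needing an argument to look at.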
I was planning to update the whole project to JAX 0.4.x, since 0.4 replaces `DeviceArray` with `Array`. It would therefore be good to include this fix in the project when we update.
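For the 0.4.x update, type checks that used `DeviceArray` could be written against the public `jax.Array` type instead. A minimal sketch, assuming JAX 0.4.x or newer; the helper name `is_jax_array` is hypothetical and not part of jumpy:

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_jax_array(x) -> bool:
    """Hypothetical helper: True for JAX arrays, replacing checks against DeviceArray."""
    # In JAX 0.4.x, jax.Array is the unified public array type that replaces DeviceArray.
    return isinstance(x, jax.Array)


print(is_jax_array(jnp.ones(3)))  # True
print(is_jax_array(np.ones(3)))   # False: NumPy arrays are not JAX arrays
```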