rlouf closed this issue 3 years ago.
Raised an issue on the JAX repo: https://github.com/google/jax/issues/4416#issue-710988308
Just to follow up on this issue: the JAX team responded on the thread, in short saying it’s a deliberate change:
> This error arises when the output of a jitted function is not a pytree of valid jax types, i.e. not a pytree of arrays (notice the `out_tracers = map(trace.full_raise, ans)` line in the traceback). In particular, it looks like a Python callable is being returned from a jitted function.
>
> We never intended to support returning non-jaxtype values (indeed the `jit` docstring says the decorated function's "arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof."). But before 0.2.0 [it] used to "work" accidentally.
>
> I say "work" in quotes because:
>
> 1. any returned non-jaxtype values had to be constants, i.e. could not depend on the values of the jitted function arguments, and
> 2. if those returned non-jaxtype values contained any JAX tracers, e.g. by being functions with tracers in their closure or being other kinds of non-pytree containers, those leaked tracers would cause mysterious and opaque errors downstream.
>
> In 0.2.0 we stopped supporting that never-intended and opaque-bug-prone behavior.
And then, as for what to do (emphasis added):

> However, due to reason (1) above, in any case where this used to work, it shouldn't be too hard to revise the code not to return the function-valued arguments, since they were constants anyway and so didn't need to be returned from the jitted function.

And they’ve closed the issue (awaiting any further questions).
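A minimal sketch of the failure and of the suggested fix, assuming a recent JAX release. `make_kernel`, `bad_step` and `good_step` are hypothetical stand-ins, not names from the original code. Returning the closure from a jitted function is rejected with a `TypeError`, even though the closure is a constant; building the function outside `jit` and returning only arrays works.

```python
import jax
import jax.numpy as jnp


def make_kernel(step_size):
    # Hypothetical kernel factory: returns a plain Python callable.
    def kernel(x):
        return x + step_size
    return kernel


@jax.jit
def bad_step(x):
    kernel = make_kernel(0.1)
    # The callable is a constant, but it is not a valid JAX type,
    # so returning it from a jitted function is an error.
    return kernel(x), kernel


def returns_callable_fails(x):
    try:
        bad_step(x)
    except TypeError:
        return True
    return False


# Fix: build the constant function outside the jitted code and only
# return arrays from the jitted function.
kernel = make_kernel(0.1)


@jax.jit
def good_step(x):
    return kernel(x)


x = jnp.zeros(3)
print(returns_callable_fails(x))  # True
print(good_step(x))  # [0.1 0.1 0.1]
```

Since the function never depended on the jitted arguments, nothing is lost by keeping it out of the return value; the caller can simply hold on to `kernel` directly.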
Thanks @lmmx! I haven't responded yet because I haven't had time to look into it and didn't want to pollute their issue tracker. I will reopen the issue if needed once I get to it. In the meantime, I pinned JAX to the previous version so I can keep testing things.
I think the problem here stems from the fact that `kernel_factory` is captured by `update_chain` from the enclosing scope, and is therefore JIT-compiled when `update_chain` is JIT-compiled. I can't see any other explanation, as `kernel_factory` is not decorated with `@jax.jit`.
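A minimal sketch of that mechanism: any Python function called inside a jitted function is traced and compiled along with it, whether or not it carries its own `@jax.jit`. `make_kernel` and `update` are hypothetical stand-ins, and the counter relies on the fact that Python side effects only run while JAX traces a function.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def make_kernel(step_size):
    # Hypothetical, undecorated factory captured from the enclosing scope.
    def kernel(x):
        global trace_count
        # Python side effects only run while JAX traces the function.
        trace_count += 1
        return x + step_size
    return kernel


@jax.jit
def update(x):
    # make_kernel is not jitted itself, but it is traced (and compiled)
    # as part of update.
    return make_kernel(0.5)(x)


x = jnp.zeros(2)
update(x)
update(x)
print(trace_count)  # 1: traced once, then the compiled version is reused
```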
If that is the issue, the solution is simple: we pass `kernel_factory` as an argument to `update_chain` as follows (note that the parameters are constant here as well, so we can mark them as static for the JIT compiler):
```python
import functools
import jax

# kernel_factory and parameters are static, so both must be hashable.
@functools.partial(jax.jit, static_argnums=(1, 2))
def update_chain(rng_key, kernel_factory, parameters, chain_state):
    kernel = kernel_factory(*parameters)
    new_chain_state, info = kernel(rng_key, chain_state)
    return new_chain_state, info
```
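A toy, self-contained run of the proposed `update_chain`, assuming kernels with the signature `kernel(rng_key, state) -> (new_state, info)`. `random_walk_factory` is a hypothetical Gaussian random walk (no accept/reject step) used only for illustration. The counter shows how the static arguments behave: the factory only runs while `update_chain` is traced, so calls with the same `kernel_factory` and `parameters` reuse the compiled code, while new parameter values trigger a retrace.

```python
import functools

import jax
import jax.numpy as jnp

factory_calls = 0


def random_walk_factory(step_size):
    # Hypothetical kernel factory: a Gaussian random walk, standing in
    # for a real MCMC kernel.
    global factory_calls
    factory_calls += 1  # only runs while update_chain is being traced

    def kernel(rng_key, state):
        new_state = state + step_size * jax.random.normal(rng_key, state.shape)
        info = jnp.linalg.norm(new_state - state)
        return new_state, info

    return kernel


@functools.partial(jax.jit, static_argnums=(1, 2))
def update_chain(rng_key, kernel_factory, parameters, chain_state):
    kernel = kernel_factory(*parameters)
    new_chain_state, info = kernel(rng_key, chain_state)
    return new_chain_state, info


key = jax.random.PRNGKey(0)
state = jnp.zeros(3)
update_chain(key, random_walk_factory, (0.1,), state)
update_chain(key, random_walk_factory, (0.1,), state)
print(factory_calls)  # 1: same static arguments, compiled code is reused
new_state, info = update_chain(key, random_walk_factory, (0.2,), state)
print(factory_calls)  # 2: new parameters, so update_chain is traced again
```

Passing `parameters` as a tuple of Python floats keeps it hashable; a JAX array would not be accepted as a static argument.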
And later, in the `update_loop` function, which advances all the chains:
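```python
@functools.partial(jax.jit, static_argnums=(2, 3))
def update_loop(state, key, kernel_factory, parameters):
    # num_chains and mcx_ravel_pytree come from the enclosing scope.
    keys = jax.random.split(key, num_chains)
    # Close over the static arguments so that vmap only maps over arrays.
    step = lambda k, s: update_chain(k, kernel_factory, parameters, s)
    state, info = jax.vmap(step)(keys, state)
    return state, info, mcx_ravel_pytree((state, info))[0]
```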
This is slightly more verbose, but if it becomes unwieldy it would make sense to gather `kernel_factory` and `parameters` in a `NamedTuple`.
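Putting the pieces together, a small self-contained sketch of both functions using the `NamedTuple` idea. `KernelSpec`, `random_walk_factory` and `num_chains` are hypothetical names chosen for illustration, and `jax.flatten_util.ravel_pytree` stands in for `mcx_ravel_pytree`. It assumes all chains share the same kernel factory and parameters; a `NamedTuple` holding a function and a tuple of Python floats is hashable, so it can be passed as a single static argument.

```python
import functools
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

num_chains = 4  # hypothetical number of chains, captured as a constant


class KernelSpec(NamedTuple):
    # Hypothetical container for the static inputs of update_chain.
    kernel_factory: Callable
    parameters: Tuple[float, ...]


def random_walk_factory(step_size):
    # Hypothetical Gaussian random-walk kernel (no accept/reject step).
    def kernel(rng_key, state):
        new_state = state + step_size * jax.random.normal(rng_key, state.shape)
        return new_state, jnp.linalg.norm(new_state - state)
    return kernel


@functools.partial(jax.jit, static_argnums=1)
def update_chain(rng_key, spec, chain_state):
    kernel = spec.kernel_factory(*spec.parameters)
    return kernel(rng_key, chain_state)


@functools.partial(jax.jit, static_argnums=2)
def update_loop(state, key, spec):
    keys = jax.random.split(key, num_chains)
    # Close over the static spec so that vmap only maps over arrays.
    state, info = jax.vmap(lambda k, s: update_chain(k, spec, s))(keys, state)
    # Flatten (state, info) into a single vector.
    return state, info, ravel_pytree((state, info))[0]


spec = KernelSpec(random_walk_factory, (0.1,))
state = jnp.zeros((num_chains, 2))
state, info, flat = update_loop(state, jax.random.PRNGKey(0), spec)
print(state.shape, info.shape, flat.shape)  # (4, 2) (4,) (12,)
```

Static arguments are compared by hash and equality, so passing an equal `KernelSpec` (same function object, same parameters) reuses the compiled code, while changing the parameters triggers a recompilation.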
While everything runs fine on v0.1.77, sampling with JAX 0.2.0 raises the following error: