rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Sampling breaks on JAX 0.2.0 #36

Closed: rlouf closed this issue 3 years ago

rlouf commented 4 years ago

While everything runs fine on v0.1.77, sampling with JAX 0.2.0 raises the following error:

        @jax.jit
        def update_chains(rng_key, parameters, chain_state):
    >       kernel = self.kernel_factory(*parameters)
    E       jax.traceback_util.FilteredStackTrace: TypeError: <class 'function'> is not a valid JAX type
    E
    E       The stack trace above excludes JAX-internal frames.
    E       The following is the original exception that occurred, unmodified.
    E
    E       --------------------

    mcx/sampling.py:207: FilteredStackTrace

    The above exception was the direct cause of the following exception:

    mcx/sampling.py:245: in run
        keys, self.parameters, state
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/traceback_util.py:137: in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/api.py:1220: in batched_fun
        axis_name=axis_name)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/batching.py:36: in batch
        return batched_fun.call_wrapped(*in_vals)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/linear_util.py:151: in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/traceback_util.py:137: in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/api.py:215: in f_jitted
        donated_invars=donated_invars)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1144: in bind
        return call_bind(self, fun, *args, **params)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1135: in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1147: in process
        return trace.process_call(self, fun, tracers, params)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/batching.py:171: in process_call
        vals_out = call_primitive.bind(f, *vals, **params)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1144: in bind
        return call_bind(self, fun, *args, **params)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1135: in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:1147: in process
        return trace.process_call(self, fun, tracers, params)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/partial_eval.py:940: in process_call
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/partial_eval.py:1005: in trace_to_subjaxpr_dynamic
        out_tracers = map(trace.full_raise, ans)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/util.py:35: in safe_map
        return list(map(f, *args))
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:358: in full_raise
        return self.pure(val)
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/interpreters/partial_eval.py:897: in new_const
        aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_python_scalar(val))
    ../../../.virtualenvs/mcx/lib/python3.7/site-packages/jax/core.py:821: in get_aval
        return concrete_aval(x)

rlouf commented 4 years ago

Raised an issue on the JAX repo: https://github.com/google/jax/issues/4416#issue-710988308

lmmx commented 4 years ago

Just to follow up on this issue: the JAX team responded on the thread, saying in short that it's a deliberate change:

This error arises when the output of a jitted function is not a pytree of valid jax types, i.e. not a pytree of arrays (notice the out_tracers = map(trace.full_raise, ans) line in the traceback). In particular, it looks like a Python callable is being returned from a jitted function.

We never intended to support returning non-jaxtype values (indeed the jit docstring says the decorated function's "arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof."). But before 0.2.0 used to "work" accidentally.

I say "work" in quotes because:

  1. any returned non-jaxtype values had to be constants, i.e. could not depend on the values of the jitted function arguments, and
  2. if those returned non-jaxtype values contained any JAX tracers, e.g. by being functions with tracers in their closure or being other kinds of non-pytree containers, those leaked tracers would cause mysterious and opaque errors downstream. In 0.2.0 we stopped supporting that never-intended and opaque-bug-prone behavior.
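A minimal sketch of the behaviour described in this reply, assuming a recent JAX release. `kernel_factory`, `returns_arrays` and `returns_callable` are hypothetical names, not mcx code. Returning a pytree of arrays from a jitted function works, while returning the closure produced by the factory raises a `TypeError`:

```python
import jax
import jax.numpy as jnp


def kernel_factory(step_size):
    """Hypothetical factory: returns a closure (a plain Python function)."""
    def kernel(x):
        return x + step_size
    return kernel


@jax.jit
def returns_arrays(x):
    # Allowed: the output is a pytree (here a dict) of arrays.
    return {"x": x, "x2": x ** 2}


@jax.jit
def returns_callable(step_size):
    # Not allowed since JAX 0.2.0: the output is a Python function,
    # and its closure even holds a tracer for step_size.
    return kernel_factory(step_size)


out = returns_arrays(jnp.array(2.0))

try:
    returns_callable(0.1)
    raised = False
except TypeError:
    raised = True
```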

and then as for what to do (emphasis added):

However, due to reason (1) above, in any case where this used to work, it shouldn't be too hard to revise the code not to return the function-valued arguments, since they were constants anyway and so didn't need to be returned from the jitted function.

and they’ve closed the issue (awaiting any further questions)
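Following the advice quoted above, one way to revise such code is to build and call the kernel inside the jitted function, so that only arrays cross the jit boundary. This is a minimal sketch; `kernel_factory` and `step` are hypothetical names, and the toy random-walk update is not a real MCMC kernel:

```python
import jax
import jax.numpy as jnp


def kernel_factory(step_size):
    """Hypothetical factory standing in for a sampler's kernel factory."""
    def kernel(rng_key, position):
        # Toy random-walk move; no accept/reject step.
        noise = jax.random.normal(rng_key, position.shape)
        return position + step_size * noise
    return kernel


@jax.jit
def step(rng_key, step_size, position):
    # The callable is created and used during tracing; it is never returned.
    kernel = kernel_factory(step_size)
    return kernel(rng_key, position)


key = jax.random.PRNGKey(0)
new_position = step(key, 0.1, jnp.zeros(3))
```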

rlouf commented 4 years ago

Thanks @lmmx! I haven't responded yet because I haven't had time to look into it and didn't want to pollute their issue tracker. I will reopen it if needed once I get to it. In the meantime, I pinned JAX to the previous version so I can keep testing things.

rlouf commented 3 years ago

I think the problem stems from the fact that kernel_factory is captured by update_chain from the enclosing scope, and is therefore JIT-compiled when update_chain is. I can't see any other explanation, since kernel_factory itself is not decorated with @jax.jit.

If that is the issue, the solution is simple: pass kernel_factory as an argument to update_chain, as follows (note that the parameters are constant as well here, so we can tell the JIT compiler that too):

import functools
import jax

@functools.partial(jax.jit, static_argnums=(1, 2))
def update_chain(rng_key, kernel_factory, parameters, chain_state):
    # kernel_factory and parameters are static, so both must be hashable
    kernel = kernel_factory(*parameters)
    new_chain_state, info = kernel(rng_key, chain_state)
    return new_chain_state, info
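A self-contained sketch of update_chain, with a hypothetical toy `kernel_factory` and `parameters` passed as a tuple of Python floats. That last choice is an assumption forced by `static_argnums`: static arguments must be hashable, so a tuple holding a JAX array is rejected:

```python
import functools

import jax
import jax.numpy as jnp


def kernel_factory(step_size):
    """Hypothetical toy kernel factory (random walk), for illustration only."""
    def kernel(rng_key, chain_state):
        noise = jax.random.normal(rng_key, chain_state.shape)
        new_state = chain_state + step_size * noise
        info = jnp.linalg.norm(noise)  # toy per-step diagnostic
        return new_state, info
    return kernel


# A function and a tuple of Python floats are both hashable,
# so they can be marked static.
@functools.partial(jax.jit, static_argnums=(1, 2))
def update_chain(rng_key, kernel_factory, parameters, chain_state):
    kernel = kernel_factory(*parameters)
    new_chain_state, info = kernel(rng_key, chain_state)
    return new_chain_state, info


key = jax.random.PRNGKey(0)
state, info = update_chain(key, kernel_factory, (0.5,), jnp.zeros(2))

# JAX arrays are not hashable, so using one inside a static argument fails.
try:
    update_chain(key, kernel_factory, (jnp.array(0.5),), jnp.zeros(2))
    array_static_ok = True
except (TypeError, ValueError):
    array_static_ok = False
```

Because static arguments are part of the compilation cache key, each new `parameters` tuple triggers a recompilation, which is acceptable only if the parameters really are constant.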

And later, in the update_loop function, which advances all the chains:

@functools.partial(jax.jit, static_argnums=(2, 3))
def update_loop(state, key, kernel_factory, parameters):
    keys = jax.random.split(key, num_chains)  # num_chains: defined in the enclosing scope
    # kernel_factory and parameters are shared by all chains, hence in_axes=None
    state, info = jax.vmap(update_chain, in_axes=(0, None, None, 0))(
        keys, kernel_factory, parameters, state
    )
    return state, info, mcx_ravel_pytree((state, info))[0]
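A runnable sketch of the loop, under two assumptions: `num_chains` is read from the leading axis of `state`, and the static arguments are captured in a closure rather than routed through `vmap`'s `in_axes`, so `vmap` only ever sees arrays. The mcx-specific `mcx_ravel_pytree` output is left out, and `kernel_factory` is again a hypothetical toy:

```python
import functools

import jax
import jax.numpy as jnp


def kernel_factory(step_size):
    """Hypothetical toy kernel factory (random walk), for illustration only."""
    def kernel(rng_key, chain_state):
        noise = jax.random.normal(rng_key, chain_state.shape)
        return chain_state + step_size * noise, jnp.linalg.norm(noise)
    return kernel


@functools.partial(jax.jit, static_argnums=(1, 2))
def update_chain(rng_key, kernel_factory, parameters, chain_state):
    kernel = kernel_factory(*parameters)
    return kernel(rng_key, chain_state)


@functools.partial(jax.jit, static_argnums=(2, 3))
def update_loop(state, key, kernel_factory, parameters):
    num_chains = state.shape[0]  # assumption: one row per chain
    keys = jax.random.split(key, num_chains)

    # Close over the static arguments so that vmap maps over arrays only.
    def one_chain(rng_key, chain_state):
        return update_chain(rng_key, kernel_factory, parameters, chain_state)

    return jax.vmap(one_chain)(keys, state)


states = jnp.zeros((4, 2))  # 4 chains, 2-dimensional state each
new_states, infos = update_loop(states, jax.random.PRNGKey(0), kernel_factory, (0.5,))
```

Capturing `kernel_factory` and `parameters` in `one_chain` is a design choice: it avoids relying on how `vmap` treats non-array arguments, while the outer `jit` still sees them as static.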

This is slightly more verbose; if it becomes too verbose, it would make sense to gather kernel_factory and parameters in a NamedTuple.
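A sketch of the NamedTuple idea; `KernelSpec` and its field names are hypothetical. A NamedTuple whose fields are hashable is itself hashable, so the bundle can be passed as a single static argument:

```python
import functools
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp


class KernelSpec(NamedTuple):
    """Hypothetical container bundling the static pieces of a sampler step."""
    kernel_factory: Callable
    parameters: Tuple[float, ...]


def kernel_factory(step_size):
    """Hypothetical toy kernel factory (random walk), for illustration only."""
    def kernel(rng_key, chain_state):
        noise = jax.random.normal(rng_key, chain_state.shape)
        return chain_state + step_size * noise, jnp.linalg.norm(noise)
    return kernel


# The whole spec is one static (hashable) argument.
@functools.partial(jax.jit, static_argnums=(1,))
def update_chain(rng_key, spec, chain_state):
    kernel = spec.kernel_factory(*spec.parameters)
    return kernel(rng_key, chain_state)


spec = KernelSpec(kernel_factory, (0.5,))
state, info = update_chain(jax.random.PRNGKey(0), spec, jnp.zeros(2))
```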