jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

grad + vmap + odeint AssertionError #8783

Open mattjj opened 2 years ago

mattjj commented 2 years ago

Discussed in https://github.com/google/jax/discussions/8782

Originally posted by **DanPuzzuoli** December 2, 2021

Hi, I've seen a bunch of discussions and issues surrounding this, so I apologize if I'm re-raising something that has already been addressed elsewhere. I don't understand the internals of JAX well enough to tell whether this is some version of an issue that has already been raised; a lot of what I'm seeing is in already-closed issues, which I assume have been solved, so I'm treating this as something different.

When I run the following code on the latest release versions of jax/jaxlib (a self-contained version of my actual code):

```
from jax.experimental.ode import odeint
from jax import jit, value_and_grad, vmap
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
config.update('jax_platform_name', 'cpu')

T = 1.
X = -1j * jnp.array([[0., 1.], [1., 0.]], dtype=complex)
Y = -1j * jnp.array([[0., -1j], [1j, 0.]], dtype=complex)

def err_obj(a, b_vals):

    def err(b):

        def rhs(y, t):
            return (b * X + a * (t**2) * Y) @ y

        res = odeint(rhs,
                     y0=jnp.eye(2, dtype=complex),
                     t=jnp.array([0, T], dtype=complex),
                     rtol=1e-6,
                     atol=1e-6)
        return jnp.abs((X * res[-1]).sum())**2 / 4

    all_err = vmap(err)(b_vals)
    return all_err.sum()

b_vals = jnp.array([1., 2., 3., 4., 5.])

jit(value_and_grad(lambda a: err_obj(a, b_vals)))(1.)
```

I get the error:

```
AssertionError: length mismatch: [1, 4]
```

I've definitely reverse-mode differentiated this code in the past with success, though it was some time ago, so it would have been on a much older version of jax/jaxlib. Based on other similar issues I've seen, it seems like this has something to do with the interaction of reverse-mode autodiff, vmap, and the control flow used in `odeint`, but again, the issues I've seen raising these kinds of errors seem to have been solved? Thanks!
mattjj commented 2 years ago

I'm not sure, but I think @danieldjohnson told me about this error a long time ago (maybe back in May?), and shared this repro which hits the same thing:

import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def foo_custom_vjp(kernel_closure, aux_args):
  return kernel_closure(*aux_args)

def foo_fwd(kernel_closure, aux_args):
  out = foo_custom_vjp(kernel_closure, aux_args)
  # Output must be saved for the error to occur.
  return out, (out,)

def foo_bwd(kernel_closure, saved, table_bar):
  return ([jnp.zeros([8, 10])],)
  raise NotImplementedError("JAX problem occurs before backward pass executes")

foo_custom_vjp.defvjp(foo_fwd, foo_bwd)

def repro(differentiable_arg, batched_arg):

  def kernel_closure(differentiable_arg):
    # This function closes over batched_arg, which is a batched NDArray;
    # it then gets nondiff-argnumed into foo_custom_vjp
    return jnp.exp(differentiable_arg + batched_arg)

  result = foo_custom_vjp(kernel_closure, [differentiable_arg])
  return jnp.mean(result)

def batched_repro(soft_prediction, samples, batch_size):
  def go(sample):
    return repro(soft_prediction, sample)
  return jnp.mean(jax.vmap(go)(samples))

@functools.partial(jax.jit)
def trigger(soft_prediction, samples):
  grads = jax.grad(batched_repro)(soft_prediction, samples, batch_size=3)
  return grads

trigger(jnp.zeros([8, 10]), jnp.zeros([7, 10], dtype=jnp.int32))

I vaguely recall that:

  1. I tried to fix it, and at least figured out the problem, but then
  2. decided that I would instead land no-more-post-process soon and that would define the issue away.

Then I failed to land that branch... and now I've forgotten everything I figured out before!

mattjj commented 2 years ago

Okay, I'm starting to get the issue again... I'm going to jot some notes here but they probably won't be comprehensible without detailed knowledge of JAX internals.

These examples exercise BatchTrace.post_process_custom_vjp_call, which is currently identical to BatchTrace.post_process_custom_jvp_call. But actually that implementation is broken!

The problem is that the number of outputs to this custom_vjp_call changes between when we run BatchTrace.post_process_custom_jvp_call and when we run the todo inside it. That is, vals in the outer lexical scope (that of the body of BatchTrace.post_process_custom_jvp_call) might have length 2, while vals inside the body of todo might have length 1. But we're using the same dims! In this particular case, the number of outputs changes because we're packaging a primal-and-tangent pair up into a single JVPTracer.


Backing up, the reason this is happening, and the reason post_process_call methods exist at all, is that the JAX tracing machinery speculatively assumes that the only transformations that apply to a primitive are those whose Tracers box some of the arguments. We assume that even for call primitives (like the one underlying custom_vjp). But for call primitives that assumption can be broken when the function-valued argument closes over Tracers of Traces not represented in the arguments. In that case, we can end up with outputs that are boxed in more Tracer levels than we expected given the arguments!

The solution is to unwrap and re-wrap output Tracers, and call back into the Traces we forgot to process so they're aware that a call primitive was bound. That's exactly what the post_process_call methods do for call primitives: when the core system calls them (just before we return from the function called by the higher-order primitive), they dutifully unwrap their Tracers and return a todo callback so that the core can put them back in the right order (after the higher-order primitive has returned).
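To make that concrete, here's a user-level sketch (not from this issue) of a call primitive whose function closes over a Tracer that never appears in its arguments; under vmap, this is exactly the situation the post_process machinery handles:

```python
import jax
import jax.numpy as jnp

def f(x):
  @jax.jit
  def g():            # g takes no arguments at all...
    return x * 2.0    # ...but it closes over the vmap Tracer x
  return g()

# vmap boxes x in a batching Tracer; the jit call sees no Tracer arguments,
# yet its output depends on x, so the batching trace has to be informed of
# the call after the fact via its post_process method.
print(jax.vmap(f)(jnp.arange(3.)))  # [0. 2. 4.]
```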


Coming back to this particular issue, starting with the case of custom_jvp (even though the repros above are for custom_vjp): the values returned by the called function are either just primals, or a flattened list of primal and tangent pairs (guaranteed to be twice the length of the just-primals version). In the latter case, dims will have length 2N (with structure [*dims_, *dims_]) when post_process_custom_jvp_call is first called, but length N (with structure dims_) when the todo is called. (In the former case there's no change in dims between the two stages.)
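As a plain-Python schematic of that length mismatch (not actual JAX internals code, just the shapes of the lists involved):

```python
# N = 2 primal outputs, each with its own batch dimension annotation
dims_ = [0, None]

# When post_process_custom_jvp_call first runs, it sees the flattened list of
# primals and tangents: 2N values, with dims of structure [*dims_, *dims_].
vals_at_post_process = ['p0', 'p1', 't0', 't1']
dims_at_post_process = [*dims_, *dims_]        # length 2N, consistent so far

# By the time the todo runs, each primal/tangent pair has been packaged into a
# single JVPTracer, so only N values remain, but the same 2N-long dims list is
# reused, and the lengths no longer line up.
vals_in_todo = ['out0', 'out1']
assert len(vals_in_todo) != len(dims_at_post_process)   # the mismatch
```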

We should plumb enough information into this function so that it knows what's going to happen and so that it can check and manipulate dims as needed. That information is available, e.g. in this suppressed _ value in the caller of process_env_traces, though I'll have to think about how best to plumb it...

mattjj commented 2 years ago

One last thing to note: in the custom_jvp example there's this 2N-or-N relationship, but for custom_vjp the outputs in question are the primals and the residuals, so there's no relationship in general. I think we have to plumb in how many residuals there are.
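For reference, here's a small self-contained custom_vjp example (unrelated to the repros in this thread) showing that the forward rule's residuals are an arbitrary bundle, with no fixed relationship to the number of primal outputs:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sin(x)                          # one primal output

def f_fwd(x):
  # Still one primal output, but three residuals; the count is up to the rule.
  return jnp.sin(x), (x, jnp.cos(x), x * x)

def f_bwd(res, g):
  _, cos_x, _ = res
  return (g * cos_x,)

f.defvjp(f_fwd, f_bwd)
print(jax.grad(f)(1.0))                      # cos(1.0)
```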

DanPuzzuoli commented 2 years ago

Not sure if this is the place for it, but can you explain what

But for call primitives that assumption can be broken when the function-valued argument closes over Tracers of Traces not represented in the arguments.

means? Specifically: what does it mean for a function to "close over Tracers"? I've seen this statement appear in a bunch of places (while looking up this particular issue). Even without understanding the JAX internals much, it seems like something I might be able to recognize in the future.

mattjj commented 2 years ago

Sure! Apologies for the jargon.

Here's an example program, without any JAX for now:

def f(x):
  def g(y):
    return x * y
  return g(3.)

f(2.)

Here I'd say the inner function g closes over the variable x: that variable occurs in the body of g while being bound not as a parameter/argument of g but instead in a "lexically enclosing scope" (meaning here the scope introduced by f, which "lexically encloses" g in the sense that the text defining g is literally inside the body of f).

This is interesting because in some sense the value of x is an input to an application of g, even though it's not in g's parameters/arguments. So one could say there are actually two kinds of inputs to a Python function: arguments and closed-over values.
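For example, two functions built from the same code but with different closed-over values behave like functions with an extra, hidden input:

```python
def make_g(x):
  def g(y):
    return x * y   # x is a closed-over "input", y is an ordinary argument
  return g

g2, g5 = make_g(2.), make_g(5.)
print(g2(3.), g5(3.))   # 6.0 15.0: same argument, different closed-over x
```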

Let's add JAX to the example:

def f(x):
  @jax.jit
  def g(y):
    return x * y
  return g(3.)

jax.grad(f)(2.)

Here again the function g closes over x. The difference is that now JAX's tracing mechanism is at work. To evaluate jax.grad(f)(2.), we box up the value of 2. in a Tracer and then use it to trace (i.e. monitor) the operations that are applied to it to produce the output of f. So in the body of f, x will refer to a Tracer instance.

But notice that when we call the jit-decorated function in evaluating g(3.), that Tracer doesn't appear in the arguments to g. Yet it'll affect the output of g! In this case I'd say that the function g closes over a Tracer when evaluating jax.grad(f)(2.).

Because the Tracers that a Python callable closes over can't easily be inspected ahead-of-time, you might only find out what Tracers are in a function's closure (and hence which transformations must be applied to the function) after running it. You might imagine that makes tracing a bit tricky.
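One way to see this concretely is to add a print to the example above (illustrative only):

```python
import jax

def f(x):
  @jax.jit
  def g(y):
    return x * y
  # Under jax.grad this prints the name of a Tracer subclass, not 'float';
  # g then closes over that Tracer.
  print(type(x).__name__)
  return g(3.)

jax.grad(f)(2.)
```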

The general mechanism which sorts out these issues is the one I alluded to in the comments above. And it's buggy for custom_vjp when closed-over vmap tracers are involved.

WDYT?

DanPuzzuoli commented 2 years ago

Ah yeah, makes sense! Thanks for the clarification. I saw that in a previous issue you'd recommended swapping the order of grad and vmap to get around this; would you still recommend doing that here?

mattjj commented 2 years ago

I'm not sure. But this is a bad bug and I plan to fix this in the next few days...
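For concreteness, the reordering being asked about looks roughly like this on a toy per-example loss (an illustrative sketch, not the code from this issue):

```python
import jax
import jax.numpy as jnp

def per_example_loss(a, b):   # toy stand-in for the real objective
  return jnp.sin(a * b) ** 2

b_vals = jnp.arange(3.)

# grad outside vmap (the ordering used in the OP)...
total = lambda a: jax.vmap(lambda b: per_example_loss(a, b))(b_vals).sum()
g1 = jax.grad(total)(1.0)

# ...versus vmap of per-example grads, summed afterwards.
g2 = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0))(1.0, b_vals).sum()

print(g1, g2)   # the two agree for a sum of per-example losses
```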

mattjj commented 2 years ago

Okay, let's try to crush this bug!

To start, here's a repro of the jax.custom_jvp version:

import jax
import jax.numpy as jnp

def h(z):
  def f(x):
    @jax.custom_jvp
    def g(y):
      return x * y

    # NOTE: rule closes over vmap tracer
    @g.defjvp
    def g_jvp(primals, tangents):
      (y,), (ydot,) = primals, tangents
      return x * y, x * ydot

    return g(z)  # NOTE: no vmapped arg

  return jax.vmap(f)(jnp.arange(3.))

jax.jvp(h, (1.,), (2.,))

mattjj commented 2 years ago

And here's a corresponding jax.custom_vjp minimal repro:

import jax
import jax.numpy as jnp

def h(z):
  def f(x):
    @jax.custom_vjp
    def g(y):
      return x * y

    def g_fwd(y):
      return x * y, (x * y, y)
    def g_rev(xys, w_bar):
      xy, _ = xys
      return (xy * w_bar,)
    g.defvjp(g_fwd, g_rev)

    return g(z)

  return jax.vmap(f)(jnp.arange(3.)).sum()

jax.grad(h)(1.)

I have a fix, in #8915, but I need to look it over and decide if there are things to clean up, and whether to break it into multiple PRs (because there were a couple cleanups I did which made the fix easier but could be made independent).

mattjj commented 2 years ago

Actually I just noticed that #8915 fixes my repro as well as the code in the comment above, but the code in the OP runs into a different error. I think it's a separate bug for me to fix...

mattjj commented 2 years ago

#8915 finally went in, but IIRC last I tried it, the repro in the OP now runs into a distinct second issue.

mattjj commented 2 years ago

Ah! The distinct second issue is just that odeint has a check for floating point time values, and it raises an error if it's given complex values (as in this example).

But if we change the repro to pass real-valued times (dropping the dtype=complex from the t array), we get yet another issue, to do with a leaked vmap tracer...
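A minimal illustration of that time-dtype check, based only on the description above (the exact error type and message may differ):

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint

rhs = lambda y, t: -y
y0 = jnp.array([1.0])

try:
  odeint(rhs, y0, jnp.array([0., 1.], dtype=complex))   # complex times
except Exception as e:                                   # rejected by the check
  print(type(e).__name__, e)

print(odeint(rhs, y0, jnp.array([0., 1.])))              # float times work
```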