jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

How to branch efficiently under jit compilation? #2065

Closed lumip closed 4 years ago

lumip commented 4 years ago

I have an algorithm that updates in a loop but with differing behavior for even and odd steps, so very simplified, I have the following

def my_func(i, *other_inputs):
  if i % 2 == 0:
    do_even_computations(*other_inputs)
  else:
    do_odd_computations(*other_inputs)

for i in range(N):
  my_func(i)

Naturally, due to the branching on a traced value, this doesn't jit-compile in this form (it fails with a concretization error). It's also not really feasible for me to push the jit down into my_func, because that is itself nested somewhat deep in code that should ideally be compiled as a whole.

Using static_argnums for my jit compilation of my_func gets rid of the error but means that my_func gets recompiled for every loop iteration, which is not really what I want either.

Working around that, I was first thinking of doing the following:

do_even_computations_jit = jax.jit(do_even_computations)
do_odd_computations_jit = jax.jit(do_odd_computations)

@functools.partial(jax.jit, static_argnums=(0,))
def my_func(i, *other_inputs):
  if i % 2 == 0:
    do_even_computations_jit(*other_inputs)
  else:
    do_odd_computations_jit(*other_inputs)

for i in range(N): # would be replaced by jax.fori_loop in real code, ofc
  my_func(i)

my_func is thus very small and compiling it would be reasonably fast - but it would still imply switching out of jax (or really, XLA) context, thus drastically reducing performance.

It would be really helpful if, since there are really only two "instances" of my function, there was some way of precompiling both of them and then having the right one chosen by the runtime depending on the condition, essentially unrolling the if.

I tried emulating that as

do_even_computations_jit = jax.jit(do_even_computations)
do_odd_computations_jit = jax.jit(do_odd_computations)

funcs = [do_even_computations_jit, do_odd_computations_jit]

@functools.partial(jax.jit, static_argnums=(0,))
def my_func(i, *other_inputs):
  return funcs[i % 2](*other_inputs)

for i in range(N): # would be replaced by jax.fori_loop in real code, ofc
  my_func(i)

however, this also fails, since a Python list cannot be indexed with a traced value.

I'm aware that it might be hard to detect the equivalence classes in a generic input that would make this possible, but maybe it could be feasible to support it for branching on boolean arguments, i.e. something like:

# telling jit to compile the function for both possible values of even
@functools.partial(jax.jit, unroll_argnums=(0,))
def my_func_simpler(even: bool, *other_inputs):
  if even:
    do_even_computations(*other_inputs)
  else:
    do_odd_computations(*other_inputs)

@jax.jit
def my_func(i, *other_inputs):
  return my_func_simpler(i % 2 == 0, *other_inputs)
  # runtime could then pick the compiled variant of my_func_simpler to execute
  # based on the actual result of i % 2 == 0

for i in range(N): # would be replaced by jax.fori_loop in real code, ofc
  my_func(i)

Would this (or something like it) be possible to implement medium-term?

What would currently be the recommended way of implementing my branching function efficiently? (I have now resorted to calling both of my computation subfunctions and then multiplexing the results, but that means I'm doing superfluous computations and I'm still looking for a better way.)
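Concretely, the multiplexing workaround I mentioned looks roughly like this (with dummy stand-ins for my actual subfunctions):

```python
import jax
import jax.numpy as jnp

def do_even_computations(x):
    return x + 1.0  # dummy stand-in for the real even-step update

def do_odd_computations(x):
    return x * 2.0  # dummy stand-in for the real odd-step update

@jax.jit
def my_func(i, x):
    # Compute BOTH branches unconditionally, then select one result
    # based on the (traced) parity of i. This jit-compiles fine, but
    # always pays for both computations.
    even_result = do_even_computations(x)
    odd_result = do_odd_computations(x)
    return jnp.where(i % 2 == 0, even_result, odd_result)

print(my_func(0, jnp.array(3.0)))  # even step: 4.0
print(my_func(1, jnp.array(3.0)))  # odd step: 6.0
```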

mattjj commented 4 years ago

Thanks for raising this, and for writing it up so beautifully.

I think I understand your idea. You probably already know this, but just to underscore it: in JAX there's no intermediate layer between pure Python and fully-staged-out-to-XLA. That is, anything that isn't fully staged out to an XLA program (like dynamic control flow deciding which XLA computations to dispatch) you can simulate just as well at the user level in pure Python, because JAX itself is pure Python. (It's an interesting marriage between super-dynamic Python and super-static super-optimizable XLA, IMO!)

That said, I can think of a couple ways you can stage this whole program out to XLA, as you'd want to for performance, though it does require slight compromises in the naturalness of your code. One thing to do is just to manually unroll an inner size-two loop over even/odd values:

# pseudocode!

def body_func(i_floordiv_2, other_inputs):  # fori_loop bodies take (index, carry)
  other_inputs = do_even_computations(2 * i_floordiv_2, other_inputs)
  return do_odd_computations(2 * i_floordiv_2 + 1, other_inputs)

out = lax.fori_loop(0, N // 2, body_func, other_inputs)
if N % 2:  # might need a lax.cond here if N isn't statically known, see below
  out = do_even_computations(N - 1, out)

(I probably got the details wrong but the spirit is in the right place, I think!)
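To make that concrete, here's a runnable version with dummy even/odd computations and a single array as the loop carry (just an illustration under those assumptions, not your real code, of course):

```python
import jax.numpy as jnp
from jax import lax

def do_even_computations(i, x):
    return x + 1.0  # dummy stand-in for the real even-step update

def do_odd_computations(i, x):
    return x * 2.0  # dummy stand-in for the real odd-step update

N = 5  # total number of steps

def body_func(k, x):
    # Iteration k performs step 2k (even) and then step 2k + 1 (odd).
    x = do_even_computations(2 * k, x)
    return do_odd_computations(2 * k + 1, x)

x = lax.fori_loop(0, N // 2, body_func, jnp.array(0.0))
if N % 2:  # N is a static Python int here, so a plain `if` is fine
    x = do_even_computations(N - 1, x)
print(x)  # steps: +1, *2, +1, *2, +1  ->  7.0
```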

An alternative to manual unrolling like that is to stage out control flow into your XLA program. We can't map regular Python control flow to XLA structured control flow (partly as a limitation of our tracing mechanism, and partly due to the complexity of mapping Python's fancy control flow to XLA structured control flow), but you can use lax.cond to stage out the conditional. That might be useful when the condition you're branching on is more complex, and can't be unrolled away as easily as the even/odd case. (lax.cond is not reverse-mode differentiable right now, just because we haven't gotten to it yet, but we plan to do that Real Soon Now.)
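Here's a minimal sketch of the lax.cond route, again with dummy branch functions:

```python
import jax
import jax.numpy as jnp
from jax import lax

def do_even_computations(x):
    return x + 1.0  # dummy branch

def do_odd_computations(x):
    return x * 2.0  # dummy branch

@jax.jit
def my_func(i, x):
    # Both branches are staged out into the XLA program; the runtime
    # picks one based on the traced predicate, with no recompilation
    # per value of i and without evaluating the untaken branch's result.
    return lax.cond(i % 2 == 0, do_even_computations, do_odd_computations, x)

print(my_func(0, jnp.array(3.0)))  # 4.0
print(my_func(1, jnp.array(3.0)))  # 6.0
```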

WDYT?

lumip commented 4 years ago

Thanks for the swift response and sorry for the delay on my part.

Unfortunately I cannot really manually unroll my code because my *other_inputs (are supposed to) change between calls of my_func. However, using lax.cond is excellent advice and I was able to implement what I wanted with it (I managed not to be aware of it so far; somehow I apparently glossed over it while reading the control flow section of the sharp bits).
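For posterity, the pattern that ended up working for me looks something like this (dummy stand-ins for my real computations again):

```python
import jax
import jax.numpy as jnp
from jax import lax

def do_even_computations(x):
    return x + 1.0  # dummy stand-in

def do_odd_computations(x):
    return x * 2.0  # dummy stand-in

def body_func(i, x):
    # lax.cond stages both branches into the XLA program and selects
    # one at runtime based on the parity of the traced loop index i.
    return lax.cond(i % 2 == 0, do_even_computations, do_odd_computations, x)

@jax.jit
def run(x):
    # The whole loop, including the branching, stays inside XLA.
    return lax.fori_loop(0, 5, body_func, x)

print(run(jnp.array(0.0)))  # +1, *2, +1, *2, +1  ->  7.0
```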

Still, I believe having some way to automatically detect the simpler cases of branching and convert them into lax.cond calls would be a pretty neat feature :)

Restricted to boolean arguments and with the hypothetical unroll_argnums parameter to jax.jit, one way I could think of how this could work would be the following:

  1. trace through the function and return True whenever the unroll-marked Boolean is coerced to real Python values
  2. do the same but returning False
  3. diff the traces
  4. keep the common parts and use lax.cond for the parts that are different

Of course, that's oversimplified and makes a lot of assumptions (only bools, they have to be annotated, they must not change within the function), but I think with a bit of careful thought to address those, it could work.

mattjj commented 4 years ago

I realized I made a typo in my pseudocode, I meant to write something like:

def body_func(i_floordiv_2, other_inputs):  # fori_loop bodies take (index, carry)
  other_inputs = do_even_computations(2 * i_floordiv_2, other_inputs)
  return do_odd_computations(2 * i_floordiv_2 + 1, other_inputs)

Not sure if that changes anything. I edited my comment above (not to use the temp variable I had before).

one way I could think of how this could work would be the following

That's a great idea! Actually a colleague of ours (not on the JAX team) prototyped something like that, using exceptions to do the backtracking, though it was just an experiment for a day and I don't think he followed up on it. I think that kind of thing is interesting, and might be a useful tool to have in a library, but isn't the sort of thing we're likely to include in core JAX just for simplicity's sake.

If lax.cond works for you, should we close this issue?

lumip commented 4 years ago

Yea, lax.cond works. Thanks again!