jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/

What does vmap(odeint) do? #2586

Closed · samuela closed 4 years ago

samuela commented 4 years ago

JAX on master (1bb9aaa88c09a90a115c0304c05b0fc25932523b) is perfectly content letting one write

vmap(lambda y0: odeint(f, y0, timesteps))(y0_vec)

But it's not immediately clear to me what this means... Does this correspond to "stacking" the ODEs into one large one and then running the solver on that? That approach would have subtle behavior for adaptive time step choices and error tolerances. Or does it mean running them each in parallel but with their own adaptive time stepping? Or something else entirely?
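For concreteness, here is a minimal runnable version of that call (the decay dynamics f, the time grid, and the batch of initial conditions are all made-up placeholders):

import jax.numpy as jnp
from jax import vmap
from jax.experimental.ode import odeint

def f(y, t):
  return -y  # hypothetical dynamics: dy/dt = -y

timesteps = jnp.linspace(0.0, 1.0, 10)
y0_vec = jnp.array([0.5, 1.0, 2.0])  # a batch of three initial conditions

ys = vmap(lambda y0: odeint(f, y0, timesteps))(y0_vec)
print(ys.shape)  # (3, 10): one solution trajectory per initial condition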

mattjj commented 4 years ago

Good question! It actually reduces to how vmap-of-while works, since there's nothing peculiar to odeint going on here. And vmap-of-while works by executing a single loop where the predicate function is essentially

def new_cond_fun(carry):
  # keep going while any of the constituent loops is still active
  return vmap(orig_cond_fun)(carry).any()

and the new body fun is effectively

def new_body_fun(carry):
  keep = vmap(orig_cond_fun)(carry)        # which loops are still running
  new_carry = vmap(orig_body_fun)(carry)   # step every loop
  return np.where(keep, new_carry, carry)  # finished loops keep their old carry

except of course pytrees are handled correctly, unlike in this impressionistic sketch.

In words: we vectorize a loop by vectorizing its cond_fun and body_fun, keep the combined loop running until all the constituent loops have terminated, and use masking to throw away the updates for loops that have already terminated.
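Here is that recipe as runnable code, specialized to a carry that is a flat tuple of arrays (no pytrees). The count-to-limit loop is a made-up example in which each lane terminates after a different number of iterations:

import jax.numpy as jnp
from jax import lax, vmap

def orig_cond_fun(carry):
  i, limit = carry
  return i < limit

def orig_body_fun(carry):
  i, limit = carry
  return i + 1, limit

def new_cond_fun(carry):
  # keep going while any lane's loop is still active
  return vmap(orig_cond_fun)(carry).any()

def new_body_fun(carry):
  keep = vmap(orig_cond_fun)(carry)       # which lanes are still running
  new_carry = vmap(orig_body_fun)(carry)  # step every lane
  # lanes that already terminated keep their old carry
  return tuple(jnp.where(keep, new, old) for new, old in zip(new_carry, carry))

limits = jnp.array([2, 5, 3])
final_i, _ = lax.while_loop(new_cond_fun, new_body_fun,
                            (jnp.zeros_like(limits), limits))
print(final_i)  # [2 5 3]: each lane ran exactly to its own limit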

Actually it's a bit smarter than that: for example, instead of batching up all the elements of the loop carry tuple, it only batches the ones that are necessary. And if the predicate function doesn't need to be batched, it isn't, and then no np.where is needed either. (The while_loop batching rule decides the minimal amount of stuff to batch, and the while_loop translation rule generates the batched code with the "any" and "select" logic. One way to improve the lowering might be to limit the batching size based on the amount of SIMD parallelism available on the target machine; that would be better for things like rejection samplers too.)
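One way to observe that minimal-batching behavior from the outside (an illustrative probe, not from the thread) is to inspect the jaxpr of a vmapped loop whose predicate depends only on an unbatched counter; the predicate should stay unbatched, with no select on the carry:

import jax
import jax.numpy as jnp
from jax import lax, vmap

def loop(x):
  def cond(carry):
    i, _ = carry
    return i < 3  # trip count shared by the whole batch

  def body(carry):
    i, x = carry
    return i + 1, x * 2.0

  _, out = lax.while_loop(cond, body, (0, x))
  return out

print(jax.make_jaxpr(vmap(loop))(jnp.ones(4)))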

This plays out automatically for odeint: none of the batch elements' step-size adaptations affect each other. Each solve executes just as it would have without being vectorized, except for some extra work that is thrown away (and is free on a perfectly parallel SIMD machine). And of course you can mix in as much reverse-mode autodiff as you want too.
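A quick sanity check of both claims, again with made-up dynamics dy/dt = -y (so y(1) = y0 * exp(-1), and each gradient should come out around 0.368):

import jax.numpy as jnp
from jax import grad, vmap
from jax.experimental.ode import odeint

def f(y, t):
  return -y

ts = jnp.linspace(0.0, 1.0, 5)
y0s = jnp.array([0.5, 1.0, 2.0])
solve = lambda y0: odeint(f, y0, ts)

# each vmapped solve matches solving that initial condition on its own
batched = vmap(solve)(y0s)
looped = jnp.stack([solve(y0) for y0 in y0s])
print(jnp.allclose(batched, looped, atol=1e-6))  # expect True

# reverse-mode autodiff composes with the vmapped solve
g = grad(lambda y0s: vmap(solve)(y0s)[:, -1].sum())(y0s)
print(g)  # roughly [0.368, 0.368, 0.368]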

Pretty cool, right?

samuela commented 4 years ago

Wow, that's even more impressive than I was expecting! Thank you so much for such a detailed explanation!