jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Primitive to sequentially execute a function inside a vmapped function #7199

Open pl-fuchs opened 3 years ago

pl-fuchs commented 3 years ago

I am writing some longer algorithms that I would like to vmap, and I ran into some problems when combining vmap with the host_callback module and the lax.cond function:

  1. vmap of cond does not work when the branches have side effects.
  2. vmap of cond can be very inefficient if one branch is much more expensive but only rarely taken, because with a batched predicate both branches are evaluated for every element.
  3. There is no trivial batching rule for the host_callback.call function
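To make point 2 concrete, here is a minimal sketch (the branch functions and inputs are made up for illustration). When the predicate passed to lax.cond is batched, JAX cannot pick a single branch, so in the JAX versions we are aware of the vmapped program evaluates both branches for every element and combines them with a select; the staged jaxpr then contains no `cond` at all.

```python
import jax
import jax.numpy as jnp
from jax import lax


def expensive_branch(x):
    # Stand-in for a long computation that is only rarely needed.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2 + x


def cheap_branch(x):
    return 2.0 * x


def g(pred, x):
    return lax.cond(pred, expensive_branch, cheap_branch, x)


preds = jnp.array([False, False, True, False])  # expensive branch taken once
xs = jnp.arange(4.0)

out = jax.vmap(g)(preds, xs)

# Inspect what vmap staged: with a batched predicate the cond is replaced by
# computing both branches for all four elements and selecting the results.
jaxpr_text = str(jax.make_jaxpr(jax.vmap(g))(preds, xs))
```

Printing `jaxpr_text` shows the arithmetic of both branches followed by a `select_n`, which is why a rarely taken but costly branch still costs its full price for every element.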

I think a simple solution would be to implement a stop_vmap or sequential_vmap decorator. This decorator would define a batching rule, such that

@vmap
@stop_vmap
def some_fun(*args):
  # Some operations ...
  return results

would be the same as writing

def some_fun(*batched_args):

  def body_fun(*args):
    # Some operation...
    return results

  return lax.map(lambda args: body_fun(*args), batched_args)
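A runnable version of the rewrite above, as a minimal sketch (the body of `body_fun` is a made-up example). Because lax.map calls the body once per slice, the predicate inside `body_fun` is a scalar there, so lax.cond stays a real conditional and only the selected branch runs for each example.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body_fun(pred, x):
    # Per-example logic; `pred` is a scalar here, so lax.cond only runs
    # the branch that is actually selected.
    return lax.cond(pred, lambda v: jnp.sin(v) + v, lambda v: 2.0 * v, x)


def some_fun(*batched_args):
    # lax.map scans over the leading axis of every array in the tuple and
    # passes one tuple of slices to the lambda per step.
    return lax.map(lambda args: body_fun(*args), batched_args)


preds = jnp.array([True, False, True])
xs = jnp.array([0.5, 1.0, 1.5])

sequential = some_fun(preds, xs)
vectorized = jax.vmap(body_fun)(preds, xs)
```

Both calls produce the same values; the difference is only in how the work is scheduled.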

The advantage of the decorator would be that some_fun could be used inside a much bigger vmapped function.

froystig commented 3 years ago

The "sequential_vmap" concept seems like it would be easy to implement given a more general custom batching interface, which is something that @mattjj and I have considered before.

Short of that, it might be worth considering offering the special case, by writing a single higher-order primitive whose batching rule is essentially a call to lax.map.
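A sketch of that special case, assuming a JAX release that provides `jax.custom_batching.custom_vmap` (recent releases also appear to ship a ready-made `jax.custom_batching.sequential_vmap` built along the same lines; check your version). The decorator name `sequential_vmap_sketch` is hypothetical, and it assumes the wrapped function takes positional array arguments rather than nested pytrees. Its batching rule is exactly "call lax.map over the batched arguments".

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.custom_batching import custom_vmap


def sequential_vmap_sketch(f):
    """Hypothetical decorator: when vmapped, run `f` one example at a time.

    Assumes `f` takes positional array arguments (not nested pytrees).
    """
    f_cv = custom_vmap(f)

    @f_cv.def_vmap
    def rule(axis_size, in_batched, *args):
        del axis_size  # lax.map infers the length from the mapped arguments
        # custom_vmap puts the batch axis of every batched argument at 0.
        mapped = [a for a, b in zip(args, in_batched) if b]

        def one_example(mapped_slices):
            slices = iter(mapped_slices)
            # Rebuild the full argument list: one slice for each batched
            # argument, the original value for each unbatched one.
            full_args = [next(slices) if b else a for a, b in zip(args, in_batched)]
            # Call the decorated function so nested vmaps also stay sequential.
            return f_cv(*full_args)

        out = lax.map(one_example, mapped)
        # Every output of lax.map is stacked along axis 0.
        out_batched = jax.tree_util.tree_map(lambda _: True, out)
        return out, out_batched

    return f_cv


@sequential_vmap_sketch
def some_fun(x, scale):
    return jnp.sin(x) * scale


xs = jnp.arange(4.0)
scale = jnp.array(2.0)
out = jax.vmap(some_fun, in_axes=(0, None))(xs, scale)
```

Unbatched arguments such as `scale` are closed over rather than broadcast, so they are not copied once per example.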

pharringtonp19 commented 3 years ago

@froystig Would breaking a vmap into a sequential_vmap reduce memory usage?

froystig commented 3 years ago

The original request is about modifying the behavior of a function under vmap. If you're asking about replacing the use of vmap entirely with a sequential map, then yes: we have lax.map, and using it may require less memory than vmap, since it processes one slice of the input at a time. Did I read your question correctly?
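A minimal sketch of the trade-off, using a made-up nearest-distance computation. Run eagerly, the vmapped version materializes an intermediate of shape (n, m, d) in one go, while lax.map only holds an (m, d) intermediate per step; under jit, XLA may fuse some of this away, so the actual peak memory depends on the program.

```python
import jax
import jax.numpy as jnp
from jax import lax


def min_distance(x, ys):
    # For one query point x of shape (d,), builds an (m, d) difference array.
    return jnp.min(jnp.linalg.norm(ys - x, axis=-1))


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (256, 3))  # n query points
ys = jax.random.normal(key_y, (512, 3))  # m reference points

# vmap processes all query points at once: the difference array is (n, m, d).
vectorized = jax.vmap(min_distance, in_axes=(0, None))(xs, ys)

# lax.map loops over xs, so each step only needs an (m, d) difference array,
# trading parallelism for a lower peak memory footprint.
sequential = lax.map(lambda x: min_distance(x, ys), xs)
```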

pharringtonp19 commented 3 years ago

@froystig Apologies, I see that this was not the right place to ask that question. You did answer it, though, so thanks!

froystig commented 3 years ago

No problem. Glad that helped!

shoyer commented 3 years ago

> The "sequential_vmap" concept seems like it would be easy to implement given a more general custom batching interface, which is something that @mattjj and I have considered before.

More general "custom batching" could be interesting for use cases like host_callback.call, where the callback invokes an external library that may support its own parallelism strategies. For example, I think this could be useful for @ianwilliamson's MEEP wrapper: https://github.com/NanoComp/meep/pull/1569
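A sketch of what such a rule could look like, assuming a JAX release with `jax.custom_batching.custom_vmap` and `jax.pure_callback` (host_callback has since been deprecated in favor of the newer callback APIs). The function `external_library_fn` is a hypothetical stand-in for a non-JAX library that accepts either a single input or a whole stacked batch; the custom rule hands it the entire batch in one callback instead of relying on JAX to batch the callback.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.custom_batching import custom_vmap

calls = []  # records the shape of every array the "library" receives


def external_library_fn(x):
    # Hypothetical external library: elementwise, so it can take one input
    # or a whole stacked batch and handle the parallelism itself.
    x = np.asarray(x)
    calls.append(x.shape)
    return (np.cos(x) + 1.0).astype(x.dtype)


def _call_library(x):
    return jax.pure_callback(
        external_library_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x
    )


@custom_vmap
def solve(x):
    return _call_library(x)


@solve.def_vmap
def solve_vmap_rule(axis_size, in_batched, x):
    del axis_size
    (x_is_batched,) = in_batched
    # Pass the whole batch (batch axis at 0) to the library in one call.
    return _call_library(x), x_is_batched


xs = jnp.linspace(0.0, 1.0, 5)
out = jax.vmap(solve)(xs)
```

Under vmap the library sees a single array of shape (5,) rather than five separate calls, which is the kind of control a general custom batching interface would expose.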

> Short of that, it might be worth considering offering the special case, by writing a single higher-order primitive whose batching rule is essentially a call to lax.map.

This could definitely be a good place to start, even if only as the first step towards the general solution.

froystig commented 3 years ago

> This could definitely be a good place to start, even if only as the first step towards the general solution.

I tend to agree, much as we wrote linear_call as a first step towards custom transposition. It may be worth trying, if only to see what it surfaces.