Closed: lumip closed this issue 4 years ago.
Thanks for raising this, and for writing it up so beautifully.
I think I understand your idea. You probably already know this, but just to underscore it: in JAX there's no intermediate layer between pure Python and fully-staged-out-to-XLA. That is, anything that isn't fully staged out to an XLA program (like dynamic control flow deciding which XLA computations to dispatch) you can simulate just as well at the user level in pure Python, because JAX itself is pure Python. (It's an interesting marriage between super-dynamic Python and super-static super-optimizable XLA, IMO!)
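Here is a minimal sketch of that user-level approach. It assumes two hypothetical step functions, `even_step` and `odd_step`, which are not from the original thread. Each is compiled once with `jax.jit`, and ordinary Python decides which one to dispatch. The dispatch itself stays in Python, so every step is a separate XLA call.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the two kinds of step; each is compiled once.
@jax.jit
def even_step(x):
    return x * 2.0

@jax.jit
def odd_step(x):
    return x + 1.0

def run(x, n_steps):
    # i is a concrete Python int, so this is plain Python control flow:
    # only the choice of which compiled function to call happens in Python.
    for i in range(n_steps):
        x = even_step(x) if i % 2 == 0 else odd_step(x)
    return x

result = run(jnp.float32(1.0), 5)  # 1 -> 2 -> 3 -> 6 -> 7 -> 14
```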
That said, I can think of a couple ways you can stage this whole program out to XLA, as you'd want to for performance, though it does require slight compromises in the naturalness of your code. One thing to do is just to manually unroll an inner size-two loop over even/odd values:
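```python
# pseudocode! do_even_computations / do_odd_computations stand in for your step functions
from jax import lax

def body_func(i_floordiv_2, other_inputs):
    # one fori_loop iteration = one even step followed by one odd step
    other_inputs = do_even_computations(2 * i_floordiv_2, *other_inputs)
    return do_odd_computations(2 * i_floordiv_2 + 1, *other_inputs)

# fori_loop carries a single value, so other_inputs is passed as one tuple
out = lax.fori_loop(0, N // 2, body_func, other_inputs)
if N % 2 == 1:  # odd N leaves one final even step; might need a lax.cond here if N isn't statically known, see below
    out = do_even_computations(N - 1, *out)
```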
(I probably got the details wrong but the spirit is in the right place, I think!)
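The following is a concrete, runnable version of the unrolled loop. It assumes toy step functions (doubling on even steps, adding one on odd steps). It also assumes the step count `n` is known at trace time, which is why it is marked with `static_argnums` and the leftover even step can be handled with a plain Python `if`.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical step functions; the step index is passed through but unused.
def do_even_computations(i, x):
    return x * 2.0

def do_odd_computations(i, x):
    return x + 1.0

@partial(jax.jit, static_argnums=1)  # n must be static for the Python `if` below
def run_unrolled(x, n):
    def body_func(i_floordiv_2, x):
        # One loop iteration covers one even step and one odd step.
        x = do_even_computations(2 * i_floordiv_2, x)
        return do_odd_computations(2 * i_floordiv_2 + 1, x)

    x = lax.fori_loop(0, n // 2, body_func, x)
    if n % 2 == 1:  # odd n: one even step (index n - 1) is left over
        x = do_even_computations(n - 1, x)
    return x

result = run_unrolled(jnp.float32(1.0), 5)  # 1 -> 2 -> 3 -> 6 -> 7 -> 14
```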
An alternative to manual unrolling like that is to stage out control flow into your XLA program. We can't map regular Python control flow to XLA structured control flow (partly because of a limitation of our tracing mechanism, and partly because of the complexity of mapping Python's fancy control flow to XLA structured control flow), but you can use `lax.cond` to stage out the conditional. That might be useful when the condition you're branching on is more complex and can't be unrolled as easily as the even/odd case. (`lax.cond` is not reverse-mode differentiable right now, just because we haven't gotten to it yet, but we plan to do that Real Soon Now.)
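Here is a minimal sketch of the `lax.cond` approach, using hypothetical toy steps. It uses the `lax.cond(pred, true_fun, false_fun, operand)` calling convention of current JAX releases; older releases used a different signature. Because `i` is traced inside `fori_loop`, a Python `if i % 2 == 0` would fail here, while `lax.cond` stages both branches out and picks one at run time.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical step functions.
def even_step(x):
    return x * 2.0

def odd_step(x):
    return x + 1.0

@jax.jit
def run_with_cond(x, n):
    def body(i, x):
        # The predicate is a traced boolean scalar; both branches are staged out.
        return lax.cond(i % 2 == 0, even_step, odd_step, x)
    return lax.fori_loop(0, n, body, x)

result = run_with_cond(jnp.float32(1.0), 5)  # 1 -> 2 -> 3 -> 6 -> 7 -> 14
```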
WDYT?
Thanks for the swift response and sorry for the delay on my part.
Unfortunately, I can't really unroll my code manually, because my `*other_inputs` (are supposed to) change between calls of `my_func`. However, using `lax.cond` is excellent advice, and I was able to implement what I wanted with it (somehow I wasn't aware of it until now; I apparently glossed over it while reading the control flow section of the Sharp Bits).
Still, I believe having some way to automatically detect the simpler cases of branching and convert them into `lax.cond` calls would be a pretty neat feature :)
Restricted to boolean arguments, and with a hypothetical `unroll_argnums` parameter to `jax.jit`, one way I could think of how this could work would be the following:
- trace the function with the unroll-marked Boolean set to `True`,
- whenever the unroll-marked Boolean is coerced to a real Python value, backtrack and also trace with it set to `False`,
- combine the two traces using `lax.cond` for the parts that are different.

Of course, that's oversimplified and rests on a lot of assumptions (only bools, they have to be annotated, they must not change within the function), but I think with a bit of careful thought to address those, it could work.
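Without such a `jit` feature, roughly the same effect can be sketched at the user level. `bool_dispatch` below is a hypothetical helper, not a JAX API. It traces a function once with a Python `True` and once with `False`, then selects between the two traces with `lax.cond`. Unlike the proposal, it requires the Boolean to be the first argument and does no automatic backtracking.

```python
import jax
import jax.numpy as jnp
from jax import lax

def bool_dispatch(f):
    """Hypothetical helper: turn f(flag: bool, *args), which may use flag in
    Python `if` statements, into g(flag, *args) where flag may be traced."""
    def g(flag, *args):
        return lax.cond(
            flag,
            lambda ops: f(True, *ops),   # trace with a concrete True
            lambda ops: f(False, *ops),  # trace with a concrete False
            args,
        )
    return g

def my_func(is_even, x):
    # Ordinary Python branching on a concrete bool.
    if is_even:
        return x * 2.0
    return x + 1.0

my_func_staged = jax.jit(bool_dispatch(my_func))

@jax.jit
def run(x, n):
    # The flag is now a traced value computed inside the compiled loop.
    return lax.fori_loop(0, n, lambda i, x: my_func_staged(i % 2 == 0, x), x)
```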
I realized I made a typo in my pseudocode; I meant to write something like:
```python
def body_func(i_floordiv_2, other_inputs):
    other_inputs = do_even_computations(2 * i_floordiv_2, *other_inputs)
    return do_odd_computations(2 * i_floordiv_2 + 1, *other_inputs)
```
Not sure if that changes anything. I edited my comment above (so it no longer uses the `temp` variable I had before).
> one way I could think of how this could work would be the following
That's a great idea! Actually a colleague of ours (not on the JAX team) prototyped something like that, using exceptions to do the backtracking, though it was just an experiment for a day and I don't think he followed up on it. I think that kind of thing is interesting, and might be a useful tool to have in a library, but isn't the sort of thing we're likely to include in core JAX just for simplicity's sake.
If `lax.cond` works for you, should we close this issue?
Yeah, `lax.cond` works. Thanks again!
I have an algorithm that updates in a loop but behaves differently on even and odd steps, so, very simplified, I have the following:
Naturally, due to the branching, this doesn't `jit`-compile in this form; it fails with an error about the abstract (traced) value used in the `if`. It's also not really feasible for me to push the `jit` down into `my_func`, because that is itself nested fairly deep in code that should ideally be compiled as a whole.
Using `static_argnums` for my `jit` compilation of `my_func` gets rid of the error, but it means that `my_func` gets recompiled for every loop iteration, which is not really what I want either.
Working around that, I was first thinking of doing the following:
`my_func` is thus very small, and compiling it would be reasonably fast, but it would still imply switching out of the JAX (or really, XLA) context, which drastically reduces performance.
It would be really helpful if, since there are really only two "instances" of my function, there were some way of precompiling both of them and having the right one chosen at run time depending on the condition, essentially unrolling the `if`.
I tried emulating that as
However, this also fails at that higher level, because the list can't be indexed with a traced index.
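For this specific pattern, choosing one of several functions with a traced integer, current JAX releases provide `lax.switch(index, branches, *operands)`. It stages all branches out and selects one at run time, clamping out-of-range indices. Here is a minimal sketch with hypothetical toy steps:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical step functions, one per "instance" of the computation.
def even_step(x):
    return x * 2.0

def odd_step(x):
    return x + 1.0

branches = [even_step, odd_step]

@jax.jit
def step(i, x):
    # Unlike indexing a Python list, lax.switch accepts a traced index.
    return lax.switch(i % 2, branches, x)
```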
I'm aware that it might be hard to detect the equivalence classes in generic inputs that would make this possible, but maybe it would be feasible to support it for branching on boolean arguments, i.e. something like:
Would this (or something like it) be possible to implement in the medium term?
What would currently be the recommended way of implementing my branching function efficiently? (I have now resorted to calling both of my computation subfunctions and then multiplexing the results, but that means I'm doing superfluous computations, so I'm still looking for a better way.)
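For comparison, here is a sketch of the multiplexing workaround next to a `lax.cond` version, assuming hypothetical toy steps. Both give the same result. The `jnp.where` version always computes both branches. `lax.cond` generally evaluates only the selected branch, although under `jax.vmap` with a batched predicate it is converted to a select, so both branches are computed anyway.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical step functions.
def even_step(x):
    return x * 2.0

def odd_step(x):
    return x + 1.0

@jax.jit
def step_multiplexed(is_even, x):
    # Workaround: compute both results and keep one of them.
    return jnp.where(is_even, even_step(x), odd_step(x))

@jax.jit
def step_cond(is_even, x):
    # Stage out the branch instead; only the chosen branch's result is returned.
    return lax.cond(is_even, even_step, odd_step, x)
```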