danijar opened this issue 3 years ago
I don't think we want AutoGraph-style conversion of Python control flow into JAX control flow by default for JIT-compiled functions. My experience has been that parsing Python control flow can be error-prone.
But this definitely seems like something that could be explored with an explicit decorator in jax.experimental. If we're lucky, perhaps we could even reuse some of the TF AutoGraph code.
Note also jax.experimental.loops, which exists for a similar purpose, but isn't quite as magical.
So @superbobry and a former intern actually prototyped implementing this as a library on top of JAX using (a bunch of private symbols in) TF AutoGraph. They implemented support for if -> cond and for -> scan (where appropriate). The code is not OSS, but I can link you to an internal copy since you are a Googler; feel free to ping me on chat if you want to take a look.
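For readers unfamiliar with these two rewrites, here is a hand-written sketch of what an if -> cond and for -> scan conversion amounts to. It is not the prototype's code (which is not public); the function names and the toy computation are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax import lax

def python_style(x, xs):
    # Ordinary Python control flow: works eagerly, but the `if` fails under
    # jax.jit because `x > 0` is a tracer there.
    y = x * 2 if x > 0 else -x
    total = y
    outs = []
    for v in xs:
        total = total + v
        outs.append(total)
    return total, jnp.stack(outs)

@jax.jit
def functional_style(x, xs):
    # if -> lax.cond: both branches are traced, one is selected at run time.
    y = lax.cond(x > 0, lambda v: v * 2, lambda v: -v, x)

    # for -> lax.scan: the loop body becomes a function of (carry, element).
    def body(total, v):
        total = total + v
        return total, total  # new carry, per-step output

    return lax.scan(body, y, xs)
```

The for -> scan rewrite only fits loops that walk the leading axis of an array with a fixed trip count, which is presumably what "where appropriate" refers to.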
IIRC, the feedback at the time was that this was a cool demo, but that functional control flow in JAX is actually not all that complicated (there are no control dependencies or side-effecting ops to reason about), so AutoGraph was not a big win (unlike in TF). Also, one of the benefits of JAX is not having these layers of indirection between users and the library. I'm sure Sergei will correct me if I am misremembering.
Thanks for the context. I think AutoGraph functionality is a productivity win for prototyping algorithms that go beyond simple feed-forward networks.
Personally, it would be hard to convince myself to give the switch from TF 2 to JAX another try if it meant I had to go back to manually writing control flow ops, which was one of the major pain points in TF 1.
In my opinion this would be most interesting for writing numerical computing primitives like ODE solvers. Right now we use a bunch of nested functions instead of for/while loops, which is rather awkward. In principle, this could let us directly JIT compile code from a library like SciPy.
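To make the "nested functions" point concrete, here is a toy fixed-step explicit Euler integrator written both ways. The integrator is chosen purely for illustration and is an assumption, not a solver from any library; what matters is the shape of the code: lax.fori_loop requires the loop body to be hoisted into a nested function that threads the loop state explicitly.

```python
import jax
import jax.numpy as jnp
from jax import lax

def euler_python(f, y0, t0, dt, n_steps):
    # The natural way to write it: a plain Python loop over time steps.
    y, t = y0, t0
    for _ in range(n_steps):
        y = y + dt * f(y, t)
        t = t + dt
    return y

def euler_lax(f, y0, t0, dt, n_steps):
    # The jit-friendly way: the loop body becomes a nested function that takes
    # and returns the whole loop state (y, t) explicitly.
    def body(i, state):
        y, t = state
        return y + dt * f(y, t), t + dt

    y, _ = lax.fori_loop(0, n_steps, body, (y0, t0))
    return y

def decay(y, t):
    return -y  # dy/dt = -y

y0 = jnp.ones(3)
# `decay`, t0, dt and n_steps are closed over so that only y0 is traced.
solve = jax.jit(lambda y: euler_lax(decay, y, 0.0, 0.01, 100))
```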
So two other transformations that would be great to have are:
while -> while_loop
x[i] = y -> x = x.at[i].set(y)
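A sketch of what these two rewrites correspond to in today's JAX, using a toy loop (an assumption chosen for illustration) that records successive powers of two below a limit. The while statement becomes lax.while_loop with explicit cond/body functions, and the in-place assignment becomes a functional .at[i].set() update that returns a new array instead of mutating the old one.

```python
import jax
import jax.numpy as jnp
from jax import lax

SIZE = 8

def powers_python(limit):
    # Mutable Python version: `while` plus in-place item assignment.
    out = [0] * SIZE
    i, value = 0, 1
    while value < limit and i < SIZE:
        out[i] = value
        value *= 2
        i += 1
    return out

@jax.jit
def powers_jax(limit):
    def cond_fun(state):
        i, value, _ = state
        # Python `and` would call bool() on tracers, so use `&` instead.
        return (value < limit) & (i < SIZE)

    def body_fun(state):
        i, value, out = state
        out = out.at[i].set(value)  # x[i] = y  ->  x = x.at[i].set(y)
        return i + 1, value * 2, out

    init = (0, 1, jnp.zeros(SIZE, dtype=jnp.int32))
    _, _, out = lax.while_loop(cond_fun, body_fun, init)
    return out
```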
The prototype was joint work with @mdanatg, and IIRC we were able to reuse most of the AutoGraph internals (with a few hacks here and there). I'm not sure how modular and reusable AutoGraph currently is, but perhaps Dan could comment.
The main caveat is that many symbols that were intended to be user-facing API within AutoGraph are still private in TF. I have long wanted to distribute it as a separate pip package independent of TF, but never had the chance to set it up. If we had that, it would have its own API and things would be more straightforward.
AutoGraph itself is fairly reusable if you can easily adhere to the operator contract for control flow (see here, here and the tests). The interface was optimized to preserve Python semantics. So you probably need to do a bit of adaptation to fit your own control flow APIs (like we do to handle the TF control flow). IIRC, that's where we had to hack things a bit last time as well. At any rate, that will give you a clean separation where you reuse practically all non-TF bits, as the TF dependencies are concentrated in the implementations of the respective operators.
Another path that's been tried for other purposes is to transform the intermediate code generated by AutoGraph, which already transformed control flow to functional form, so it's mainly a matter of renaming functions and moving bits of AST around.
The two alternatives are dual: with AutoGraph generating things like ag__.if_stmt(...), the main difference is whether you supply a custom implementation for if_stmt or whether you transform that into lax.cond(...).
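A minimal sketch of the first option, supplying a custom implementation of the if operator. The name if_stmt and its (pred, true_fn, false_fn, *operands) signature are hypothetical simplifications, not AutoGraph's actual operator contract. The dispatch rule mirrors the TF behaviour described in the original issue: array-valued (including traced) predicates become functional control flow, while plain Python values keep Python semantics.

```python
import jax
from jax import lax

def if_stmt(pred, true_fn, false_fn, *operands):
    """Hypothetical JAX-backed operator for `if` statements (illustrative only).

    Array predicates (concrete or traced) are staged out with lax.cond, which
    traces both branches. Plain Python predicates run only the selected branch.
    """
    # Tracers also pass isinstance(..., jax.Array), so this covers the jit case.
    if isinstance(pred, jax.Array):
        return lax.cond(pred, true_fn, false_fn, *operands)
    return true_fn(*operands) if pred else false_fn(*operands)

@jax.jit
def double_or_negate(x):
    # Roughly what converted code for `if x > 0: x * 2 else: -x` could call.
    return if_stmt(x > 0, lambda v: v * 2, lambda v: -v, x)
```

With a Python bool, if_stmt(False, lambda: 1 / 0, lambda: 5) never calls the first branch, matching ordinary Python semantics.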
Hey! What's the current status of this?
An idea for a syntax that makes control flow faster to write/edit/read while keeping it clear to the user what's happening:
def cond_example(pred, true_fun, false_fun, *operands):
    if lax.inline_cond(pred):
        return true_fun(*operands)
    else:
        return false_fun(*operands)

def scan_example(f, init, xs):
    carry = init
    ys = []
    for x in lax.inline_scan(xs):
        carry, y = f(carry, x)
        ys.append(y)
    return carry, lax.inline_stack(ys)

def while_example(cond_fun, body_fun, init_val):
    val = init_val
    while lax.inline_while(cond_fun(val)):
        val = body_fun(val)
    return val
Without JIT compilation, the new functions lax.inline_cond(), lax.inline_scan(), lax.inline_stack(), and lax.inline_while() have trivial implementations.
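A sketch of those trivial eager implementations, assuming "trivial" means near-identity functions. Since lax.inline_* do not exist in JAX, they are defined here as hypothetical plain functions, and the proposal's examples are repeated against them so they run eagerly; lax.scan is used only to check that the results agree. Under jit these stand-ins would not work on their own, because bool() on a tracer raises.

```python
import jax
import jax.numpy as jnp

# Hypothetical eager stand-ins for the proposed lax.inline_* helpers.
# Outside of jit every value is concrete, so each helper is (nearly) the identity.
def inline_cond(pred):
    return bool(pred)  # concrete predicate -> ordinary Python bool

def inline_while(pred):
    return bool(pred)

def inline_scan(xs):
    return iter(xs)  # iterate over the leading axis, as lax.scan does

def inline_stack(ys):
    return jnp.stack(ys)  # stack per-step outputs like scan's stacked ys

# The proposal's examples, written against the stand-ins.
def cond_example(pred, true_fun, false_fun, *operands):
    if inline_cond(pred):
        return true_fun(*operands)
    else:
        return false_fun(*operands)

def scan_example(f, init, xs):
    carry = init
    ys = []
    for x in inline_scan(xs):
        carry, y = f(carry, x)
        ys.append(y)
    return carry, inline_stack(ys)

def while_example(cond_fun, body_fun, init_val):
    val = init_val
    while inline_while(cond_fun(val)):
        val = body_fun(val)
    return val

def cumsum_step(carry, x):
    carry = carry + x
    return carry, carry

carry, ys = scan_example(cumsum_step, jnp.zeros(()), jnp.arange(4.0))
# carry == 6.0, ys == [0., 1., 3., 6.]; jax.lax.scan gives the same values.
```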
@mattjj does this seem like a good approach or do you see any issues or improvements?
@danijar, how do you imagine implementing these helper functions? Would this still need some form of AST or bytecode rewrite?
Yes, exactly. I think we'll need to look at the AST like AutoGraph does.
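As a rough sketch of the first step of such a rewrite, Python's standard ast module can locate the if/for/while statements that use the proposed markers. Only detection is shown; generating the corresponding lax.cond/lax.scan/lax.while_loop calls (working out which variables are loop-carried, and so on) is the hard part that AutoGraph handles and is not attempted here. The marker names come from the proposal and are hypothetical.

```python
import ast
import textwrap

# Hypothetical marker names from the proposal above.
MARKERS = {"inline_cond", "inline_scan", "inline_while"}

def _marker_name(node):
    """Return the marker name if `node` is a call like lax.inline_cond(...)."""
    if isinstance(node, ast.Call):
        func = node.func
        if isinstance(func, ast.Attribute) and func.attr in MARKERS:
            return func.attr
        if isinstance(func, ast.Name) and func.id in MARKERS:
            return func.id
    return None

def find_marked_control_flow(source):
    """List (statement type, marker, line number) for marked if/for/while statements."""
    tree = ast.parse(textwrap.dedent(source))
    found = []
    for node in ast.walk(tree):
        if isinstance(node, (ast.If, ast.While)):
            marker = _marker_name(node.test)  # the condition expression
        elif isinstance(node, ast.For):
            marker = _marker_name(node.iter)  # the iterable expression
        else:
            continue
        if marker is not None:
            found.append((type(node).__name__, marker, node.lineno))
    return found
```

Unmarked statements such as a plain `if pred:` are left alone, which keeps the conversion explicit and opt-in.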
Maybe this should be an explicit function transformation in JAX? I'm still quite new to JAX so I don't know what the best way to make this available to users is.
The AutoGraph feature in TensorFlow 2 lets tf.function() compile Python if/for/while statements into a graph. This feature can be a big productivity win because one never has to manually create symbolic control flow operations.
Specifically, statements whose conditionals evaluate to tensor types are compiled into symbolic control flow operations like tf.cond()/tf.while_loop(). All other statements are traced, so that only the selected branch of an if/else statement is added to the graph and loops are unrolled statically; these get re-traced when the Python function inputs change.
Has this functionality been discussed and considered for JAX as well? It's the main reason I'm not using JAX at the moment.
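For comparison, a sketch of the closest built-in JAX behaviour today, using toy functions whose names are illustrative. A Python if on a static argument behaves like the traced (non-tensor) case described above, including re-tracing when the value changes, whereas a Python if on a traced array raises an error and has to be written with lax.cond.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="use_relu")
def activation(x, use_relu):
    # `use_relu` is a static Python bool: this `if` runs at trace time, only
    # the selected branch is traced, and a new value triggers a re-trace.
    if use_relu:
        return jnp.maximum(x, 0.0)
    return jnp.tanh(x)

@jax.jit
def python_if_on_array(x):
    # `x` is a tracer here, so this Python `if` cannot be evaluated.
    if x > 0:
        return x
    return -x

@jax.jit
def lax_cond_on_array(x):
    # The explicit functional form that JAX requires today.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)

def python_if_fails_under_jit():
    try:
        python_if_on_array(1.0)
    except jax.errors.ConcretizationTypeError:
        # bool() on a tracer raises this error (or, in recent versions, a subclass).
        return True
    return False
```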