sharadmv opened this issue 2 years ago
Hi, for a non-PL/compiler person, can you say a bit about what this means?
Are there any user-facing implications, such as being able to express impure loops at the cost of sequential execution, or being able to write loops that can be compiled to parallel functional primitives based on their effect type? (Kinda like what Dex does.)
> Hi, for a non-PL/compiler person, can you say a bit about what this means?
We're exploring implementing a new control-flow primitive (`for`) that generalizes `scan` and is more flexible/expressive. We're also exploring a "state" side effect in JAX and its ramifications. Right now, we're confining the side effect to the body of this new `for` primitive. For some examples of its usage, you can see the control flow test.
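The experimental `for` primitive's API isn't shown in this thread, so here is only a rough functional analogue using the public `jax.lax.fori_loop`. The loop body "writes" into a buffer with `.at[...].set(...)` (a functional update, not real mutation), and because the buffer is created inside the function and never escapes the loop, the function as a whole stays pure and can be passed to `jax.jit`. The name `reverse_with_loop` is hypothetical.

```python
import jax
import jax.numpy as jnp


def reverse_with_loop(x):
    """Reverse a 1-D array by writing x[i] into out[n - 1 - i]."""
    n = x.shape[0]
    # The buffer is allocated inside the function, so the "writes" below
    # are invisible to callers and the function stays pure.
    out = jnp.zeros_like(x)

    def body(i, out):
        # Functional stand-in for the in-place write out[n - 1 - i] = x[i].
        return out.at[n - 1 - i].set(x[i])

    return jax.lax.fori_loop(0, n, body, out)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(jax.jit(reverse_with_loop)(x))  # [4. 3. 2. 1.]
```

Writing to `out[n - 1 - i]` rather than to the "next" output slot is the kind of access pattern that a loop with reads/writes expresses directly.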
> Are there any user-facing implications such as being able to express impure loops?
Yes, potentially. The `for` implementation is still in an exploratory stage, but it is a generalization of `scan`, because we can express more than just `scan`-like patterns (see the `cumsum` example in the tests).
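As an illustration (not the test from the JAX repository), here is a cumulative sum written two ways with public APIs: once as a `jax.lax.scan`, where the running total must be threaded through the carry, and once in a `for`-like style with `jax.lax.fori_loop`, where each iteration reads `out[i - 1]` from a preallocated buffer and writes `out[i]`. The function names are hypothetical, and the loop version assumes a non-empty input.

```python
import jax
import jax.numpy as jnp


def cumsum_scan(x):
    # scan style: the running total is the carry; each step emits one output.
    def step(total, xi):
        total = total + xi
        return total, total

    _, ys = jax.lax.scan(step, jnp.zeros((), x.dtype), x)
    return ys


def cumsum_loop(x):
    # for style: read out[i - 1] and write out[i] in a preallocated buffer.
    # Assumes x is non-empty (x[0] is read up front).
    out = jnp.zeros_like(x).at[0].set(x[0])

    def body(i, out):
        return out.at[i].set(out[i - 1] + x[i])

    return jax.lax.fori_loop(1, x.shape[0], body, out)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(cumsum_scan(x))  # [ 1.  3.  6. 10.]
print(cumsum_loop(x))  # [ 1.  3.  6. 10.]
```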
XLA doesn't have a (general) "parallel for" like Dex does, but you're exactly right about the direction we hope to go. If we imagine we're lowering to a Dex backend, we could potentially parallelize the for loop.
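One way to see which loops could be parallelized: if no iteration reads anything another iteration writes, the loop computes the same thing as a map over the loop index, which JAX can already express with `jax.vmap`. The sketch below checks that equivalence for a made-up loop body (`loop_body`, `sequential`, and `parallel` are hypothetical names); it says nothing about how a Dex-style backend would actually schedule such a loop.

```python
import jax
import jax.numpy as jnp


def loop_body(i, x):
    # Iteration i reads only x[i] and produces only out[i]: no iteration
    # depends on another, so the order of iterations does not matter.
    return jnp.sin(x[i]) * i


def sequential(x):
    out = jnp.zeros_like(x)

    def step(i, out):
        return out.at[i].set(loop_body(i, x))

    return jax.lax.fori_loop(0, x.shape[0], step, out)


def parallel(x):
    # The same loop expressed as a map over the loop index.
    return jax.vmap(lambda i: loop_body(i, x))(jnp.arange(x.shape[0]))


x = jnp.linspace(0.0, 1.0, 5)
print(jnp.allclose(sequential(x), parallel(x)))  # True
```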
Hope this is helpful!
Yes that's very informative, thanks! Sounds like some cool facilities are in the pipeline.
Can we expect some performance improvements by re-implementing `scan` in terms of `for`?
> Can we expect some performance improvements by re-implementing `scan` in terms of `for`?
Probably not. However, some patterns cannot be expressed efficiently as a `scan` but can be as a `for_loop`. In these cases, you can use `for_loop` and potentially see a speedup.
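The thread doesn't name specific patterns, so as one plausible example: a loop where each iteration adds to a data-dependent element of an accumulator, as in a histogram. This is the kind of `ref[idx] += value` update that the `addupdate` primitive in this issue's checklist is meant to express; here it is only approximated with `jax.lax.fori_loop` and `.at[...].add(...)`. No performance claim is made, and `histogram_loop` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def histogram_loop(idx, num_bins):
    """Count occurrences of each bin index in idx (values in [0, num_bins))."""
    hist = jnp.zeros(num_bins, dtype=jnp.int32)

    def body(i, hist):
        # Accumulate into a data-dependent slot: hist[idx[i]] += 1.
        return hist.at[idx[i]].add(1)

    return jax.lax.fori_loop(0, idx.shape[0], body, hist)


idx = jnp.array([0, 2, 2, 1, 2])
print(histogram_loop(idx, 3))  # [1 1 3]
```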
@sharadmv Is it possible to define new MLIR lowering rules for primitives outside of the core JAX repo? Or: is that part of JAX accessible to users?
This is a bit of a side question, so I can open a discussion if you'd like to move it over there.
> This is a bit of a side question, so I can open a discussion if you'd like to move it over there.
Yes, I'd prefer this discussion to happen elsewhere, though I can give a brief answer here.
> Is it possible to define new MLIR lowering rules for primitives outside of the core JAX repo? Or: is that part of JAX accessible to users?
Yes and no. Yes, it is possible to register lowering rules for your custom primitives via `jax.interpreters.mlir.register_lowering`; and no, because it is an internal API, so there are no stability guarantees and your code could break at any time.
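For illustration only, a minimal sketch of registering a lowering for a hypothetical `double` primitive. It relies on `jax.interpreters.mlir.register_lowering` and `mlir.lower_fun`, which, as noted in the answer, are internal and may change or disappear between JAX releases; the import location of `Primitive` has also moved over time, hence the fallback import.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir  # internal API: no stability guarantees

try:  # newer JAX releases expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older releases
    from jax.core import Primitive

# A hypothetical primitive that doubles its input.
double_p = Primitive("double")


def double(x):
    return double_p.bind(x)


# Eager (un-jitted) evaluation.
double_p.def_impl(lambda x: x * 2)

# Abstract evaluation: the output has the same shape and dtype as the input,
# so the input's abstract value is returned unchanged.
double_p.def_abstract_eval(lambda x: x)

# Lowering used under jit: trace an ordinary JAX function and lower it to MLIR.
mlir.register_lowering(
    double_p, mlir.lower_fun(lambda x: x * 2, multiple_results=False)
)

x = jnp.arange(3.0)
print(double(x))           # [0. 2. 4.]
print(jax.jit(double)(x))  # [0. 2. 4.]
```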
With jaxprs now supporting effect types, we can express side effects like the State monad, where references can be read from and written to (i.e. mutation). We can use state to implement a simpler `scan` control-flow primitive via a `for` primitive that supports reads/writes. This issue will track the implementation progress:
- `get`/`swap`/`addupdate` primitives
  - `impl` rules
  - `abstract_eval` rules
  - `jvp` rules
  - `transpose` rules
  - `vmap` rules
- `for` primitive
  - `impl` rule
  - `abstract_eval` rule
  - `jvp` rule
  - `partial_eval`
  - `transpose`
  - `vmap` rule
  - `partial_eval_custom` rule
- `scan` in terms of `for`
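A sketch of the last checklist item, under heavy assumptions: `ref_get`, `ref_swap`, and `ref_addupdate` are hypothetical state-passing stand-ins for the issue's `get`/`swap`/`addupdate` primitives (they return an updated array instead of mutating a reference), and `scan_via_for` is a hypothetical `scan` built on `jax.lax.fori_loop`. It handles only a single-array input and a single-array per-step output, and it is not the implementation referenced in this issue.

```python
import jax
import jax.numpy as jnp

# Hypothetical state-passing stand-ins for get/swap/addupdate: instead of
# mutating a reference, each takes the current buffer and returns what changed.


def ref_get(buf, i):
    # Read buf[i].
    return buf[i]


def ref_swap(buf, i, val):
    # Write val to buf[i]; return the old value and the updated buffer.
    return buf[i], buf.at[i].set(val)


def ref_addupdate(buf, i, val):
    # buf[i] += val, returning the updated buffer.
    return buf.at[i].add(val)


def scan_via_for(f, init, xs):
    """Intended to match jax.lax.scan(f, init, xs) for a non-empty array xs
    and a single-array per-step output, written as a loop over indices."""
    n = xs.shape[0]
    # Find the shape/dtype of one step's output without running f.
    _, y_struct = jax.eval_shape(f, init, xs[0])
    ys = jnp.zeros((n,) + y_struct.shape, y_struct.dtype)

    def body(i, state):
        carry, ys = state
        carry, y = f(carry, ref_get(xs, i))  # read xs[i]
        _, ys = ref_swap(ys, i, y)           # write ys[i]
        return carry, ys

    return jax.lax.fori_loop(0, n, body, (init, ys))


def step(total, x):
    total = total + x
    return total, total


xs = jnp.array([1.0, 2.0, 3.0])
carry, ys = scan_via_for(step, jnp.float32(0.0), xs)
ref_carry, ref_ys = jax.lax.scan(step, jnp.float32(0.0), xs)
print(carry, ys)                            # 6.0 [1. 3. 6.]
print(ref_addupdate(jnp.zeros(3), 1, 5.0))  # [0. 5. 0.]
```

Threading `ys` through the loop carry is the functional encoding of state; a `for` primitive with real references would presumably let the body write to `ys` without returning it.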
The "raw" version of `for` can be found here. Next steps involve porting that code to JAX core and adding tests.