jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Tracker: decomposing scan (aka "Five Loop") #10982

Open sharadmv opened 2 years ago

sharadmv commented 2 years ago

With jaxprs now supporting effect types, we can express side effects such as state (in the sense of the State monad), where references can be read from and written to (i.e., mutation). We can use state to implement the scan control-flow primitive more simply, on top of a for primitive whose body can read from and write to references.
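To make the idea concrete, here is a toy, pure-Python model of the semantics described above, not JAX's implementation: a for loop whose body receives mutable references it can read and write, and a scan built on top of it. The names Ref, for_loop, and scan are illustrative assumptions; the real primitive operates on traced jaxprs, not Python lists. The scan body follows the lax.scan convention f(carry, x) -> (carry, y).

```python
class Ref:
    """Toy mutable reference (hypothetical): supports ref[i] reads and ref[i] = v writes."""

    def __init__(self, value):
        self.value = value

    def __getitem__(self, idx):
        return self.value[idx]

    def __setitem__(self, idx, val):
        self.value[idx] = val


def for_loop(nsteps, body, init_state):
    """Toy `for` (hypothetical): calls body(i, refs) for i in 0..nsteps-1.

    The body may read and write the refs. The mutation stays inside the loop:
    the refs wrap copies of init_state and only their final values are returned.
    """
    refs = [Ref(list(s)) for s in init_state]
    for i in range(nsteps):
        body(i, refs)
    return [r.value for r in refs]


def scan(f, init, xs):
    """scan expressed via the toy for_loop, with f: (carry, x) -> (carry, y)."""
    n = len(xs)

    def body(i, refs):
        carry_ref, xs_ref, ys_ref = refs
        carry, y = f(carry_ref[0], xs_ref[i])  # read the current carry and x_i
        carry_ref[0] = carry                   # write the new carry back
        ys_ref[i] = y                          # write output y_i into its slot

    carry_box, _, ys = for_loop(n, body, [[init], xs, [None] * n])
    return carry_box[0], ys
```

For example, scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4]) returns (10, [1, 3, 6, 10]). Scan only ever writes ys[i] at step i and threads a single carry; a for body can read and write any slot of any ref, which is why it is the more general primitive.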

This issue will track the implementation progress:

The "raw" version of for can be found here. Next steps involve porting that code to JAX core and adding tests.

AriMKatz commented 2 years ago

Hi, for a non-PL/compiler person, can you say a bit about what this means?

Are there any user-facing implications, such as being able to express impure loops at the cost of sequential execution, or being able to write loops that can be compiled to parallel functional primitives based on their effect type (kind of like what Dex does)?

sharadmv commented 2 years ago

> Hi, for a non-PL/compiler person, can you say a bit about what this means?

We're exploring a new control-flow primitive (for) that generalizes scan and is more flexible and expressive. We're also exploring a "state" side effect in JAX and its ramifications. Right now, the side effect is confined to the body of this new for primitive. For some examples of its usage, see the control flow tests.

> Are there any user-facing implications, such as being able to express impure loops?

Yes, potentially. The for implementation is still exploratory, but it generalizes scan: it can express more than just scan-like patterns (see the cumsum example in the tests).
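As a rough illustration of the kind of in-place pattern meant here (this is not the cumsum test itself, which isn't reproduced in this thread), the sketch below computes a prefix sum by updating one buffer step by step: step i reads slot i-1 and writes slot i. In today's public JAX API the closest way to write this is jax.lax.fori_loop with functional .at[...] updates, which is what the sketch assumes; with a for primitive, the body would instead write into a reference directly.

```python
import jax
import jax.numpy as jnp


def cumsum_inplace(x):
    """Prefix sum computed by repeatedly updating a single buffer."""

    def body(i, acc):
        # Read the running total stored at i - 1 and add it into slot i.
        return acc.at[i].add(acc[i - 1])

    # Start from x itself; slot 0 already holds its final value.
    return jax.lax.fori_loop(1, x.shape[0], body, x)
```

Calling cumsum_inplace(jnp.array([1.0, 2.0, 3.0, 4.0])) gives [1.0, 3.0, 6.0, 10.0], matching jnp.cumsum.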

XLA doesn't have a (general) "parallel for" like Dex does, but you're exactly right about the direction we hope to go. If we were lowering to a Dex backend instead, we could potentially parallelize the for loop.
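A loose analogy for when a loop can be parallelized: if iteration i only reads x[i] and writes out[i], there is no dependency between iterations, so the loop computes the same thing as a batched map. In JAX today that batched form can be written with jax.vmap. The sketch below only shows this equivalence on a hypothetical per-element function f; it is not how the for primitive is (or would be) lowered.

```python
import jax
import jax.numpy as jnp


def f(xi):
    # Hypothetical per-element work: iteration i depends only on x[i].
    return jnp.sin(xi) * 2.0


def sequential(x):
    # Loop form: step i reads x[i] and writes out[i], nothing else.
    def body(i, out):
        return out.at[i].set(f(x[i]))

    return jax.lax.fori_loop(0, x.shape[0], body, jnp.zeros_like(x))


def parallel(x):
    # No loop-carried dependency, so the same computation can be batched.
    return jax.vmap(f)(x)
```

A loop that also reads out[i - 1], like the prefix sum, does have a loop-carried dependency and cannot be rewritten this way.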

Hope this is helpful!

AriMKatz commented 2 years ago

Yes, that's very informative, thanks! Sounds like some cool features are in the pipeline.

LoicRaillon commented 2 years ago

Can we expect some performance improvements by re-implementing scan in terms of for?

sharadmv commented 2 years ago

> Can we expect some performance improvements by re-implementing scan in terms of for?

Probably not. However, some patterns can't be expressed efficiently as a scan but can be as a for_loop. In those cases, you can use for_loop and may see a speedup.

femtomc commented 2 years ago

@sharadmv Is it possible to define new MLIR lowering rules for primitives outside of the core JAX repo? Or is that part of JAX user-accessible?

This is a bit of a side question, so I can open a discussion if you'd like to move it over there.

sharadmv commented 2 years ago

> This is a bit of a side question, so I can open a discussion if you'd like to move it over there.

Yes, I'd prefer to have this discussion elsewhere, though I can give a brief answer here.

> Is it possible to define new MLIR lowering rules for primitives outside of the core JAX repo? Or is that part of JAX user-accessible?

Yes and no. Yes, it is possible to register lowering rules for your custom primitives via jax.interpreters.mlir.register_lowering. And no, because it is an internal API: there are no stability guarantees, and your code could break at any time.
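A minimal sketch of what that looks like, under the same caveat that these are internal APIs that may change between JAX versions. The primitive double_p is hypothetical. The sketch assumes Primitive is importable from jax.extend.core (newer releases) or jax.core (older ones), and it uses mlir.lower_fun to build the lowering rule from an equivalent JAX function rather than emitting MLIR ops by hand.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

try:
    from jax.extend.core import Primitive  # assumed location in newer JAX
except ImportError:
    from jax.core import Primitive  # assumed location in older JAX

# Hypothetical primitive that doubles its input.
double_p = Primitive("double")

# Eager (non-jit) evaluation rule.
double_p.def_impl(lambda x: 2 * x)

# Abstract evaluation: the output has the same shape/dtype as the input.
double_p.def_abstract_eval(lambda x: x)

# Lowering rule (internal API): lower by tracing an equivalent JAX function.
mlir.register_lowering(
    double_p, mlir.lower_fun(lambda x: 2 * x, multiple_results=False)
)


def double(x):
    # No JVP/transpose rules are registered, so jax.grad(double) would fail.
    return double_p.bind(x)
```

With this in place, both double(3.0) and jax.jit(double)(jnp.arange(3.0)) work; the second goes through the registered lowering rule.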