jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Is it possible to have early stopping in `lax.scan`? #5642

Open salayatana66 opened 3 years ago

salayatana66 commented 3 years ago

I am considering a loop of operations, each of which is expensive. The worst-case number of iterations can be bounded by a fixed N, but there is also a stopping condition, and on average the loop terminates after K iterations with K << N.

With lax.while_loop it is relatively easy to implement this efficiently, but unfortunately one loses reverse-mode differentiation. Would it be possible to have a version of lax.scan that supports such early stopping?
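For concreteness, here is a minimal sketch of the lax.while_loop version of such a loop. The Newton iteration for a square root, the function name `newton_sqrt_while`, and the constants `MAX_STEPS` and `TOL` are all illustrative assumptions, not part of the original question; they stand in for the expensive step, the bound N, and the stopping condition.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_STEPS = 50  # worst-case bound N (illustrative value)
TOL = 1e-5      # stopping tolerance (illustrative value)

def newton_sqrt_while(a):
    """Newton iteration for sqrt(a) that stops as soon as it has converged."""
    def cond_fun(state):
        i, x = state
        # Keep looping while under the bound and not yet converged.
        return (i < MAX_STEPS) & (jnp.abs(x * x - a) > TOL)

    def body_fun(state):
        i, x = state
        # Stand-in for the expensive per-step operation.
        return i + 1, 0.5 * (x + a / x)

    init = (jnp.zeros((), jnp.int32), jnp.asarray(a, dtype=jnp.float32))
    return lax.while_loop(cond_fun, body_fun, init)

steps, root = newton_sqrt_while(2.0)

# Forward-mode differentiation goes through lax.while_loop.
_, tangent = jax.jvp(lambda a: newton_sqrt_while(a)[1], (2.0,), (1.0,))
```

Only the iterations that are actually needed run here, which is the efficiency in question; the missing piece is reverse mode (for example jax.grad) through the loop.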

shoyer commented 3 years ago

You could do this with lax.cond inside lax.scan -- just apply the identity function if the loop should terminate early.

The downside is that you will pay the price of memory allocation for all N steps. This is a limitation of XLA's memory model: all memory allocation must use statically known shapes.
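A minimal sketch of what lax.cond inside lax.scan could look like, under some assumptions: the Newton square-root step, the names `newton_sqrt_scan`, `skip`, and `update`, and the constants `MAX_STEPS` and `TOL` are hypothetical stand-ins for the expensive operation, the bound N, and the stopping condition.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_STEPS = 50  # worst-case bound N (illustrative value)
TOL = 1e-5      # stopping tolerance (illustrative value)

def newton_sqrt_scan(a):
    """Newton iteration for sqrt(a) with a fixed-length scan and a cond guard."""
    def step(carry, _):
        x, n_steps = carry
        done = jnp.abs(x * x - a) <= TOL

        def skip(operand):
            # Identity branch: pass the carry through once converged.
            return operand

        def update(operand):
            x, n_steps = operand
            # Stand-in for the expensive per-step operation.
            return 0.5 * (x + a / x), n_steps + 1

        carry = lax.cond(done, skip, update, (x, n_steps))
        return carry, None

    init = (jnp.asarray(a, dtype=jnp.float32), jnp.zeros((), jnp.int32))
    (root, n_steps), _ = lax.scan(step, init, xs=None, length=MAX_STEPS)
    return root, n_steps

root, n_steps = newton_sqrt_scan(2.0)

# Reverse-mode differentiation works because the scan length is static.
grad_root = jax.grad(lambda a: newton_sqrt_scan(a)[0])(2.0)
```

The scan still runs all `MAX_STEPS` iterations; after convergence each one only takes the identity branch, so the expensive work is skipped but the loop itself is not shortened.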

mattjj commented 3 years ago

This came up in one of our chat rooms yesterday, and it was also observed that because of how we currently implement vmap-of-cond (namely we always turn it into a select), using cond for this might lead to inefficient batching.

I think adding an early stopping version of lax.scan is a good idea. It's not trivial though. Notes to self: we'd probably need to add a jaxpr for the early-stopping function, and also add 'start' and 'stop' arguments to the scan for the linearized case.

As @shoyer observed, we'd always pay the memory cost.

mattjj commented 3 years ago

Until we add this (or figure out a better plan), you might be able to work around the issue using jax.custom_vjp around your loop.
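One way this workaround might look, as a sketch only: wrap a lax.while_loop in jax.custom_vjp and supply the backward rule by hand. The example assumes the loop computes a square root by Newton iteration, so the backward rule can use the implicit relation root**2 == a instead of differentiating through the iterations. That choice, along with the names `newton_sqrt`, `MAX_STEPS`, and `TOL`, is an illustrative assumption; a real loop needs its own hand-derived rule.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_STEPS = 50  # worst-case bound N (illustrative value)
TOL = 1e-5      # stopping tolerance (illustrative value)

@jax.custom_vjp
def newton_sqrt(a):
    """sqrt(a) by Newton iteration, stopping early via lax.while_loop."""
    def cond_fun(state):
        i, x = state
        return (i < MAX_STEPS) & (jnp.abs(x * x - a) > TOL)

    def body_fun(state):
        i, x = state
        return i + 1, 0.5 * (x + a / x)

    init = (jnp.zeros((), jnp.int32), jnp.asarray(a, dtype=jnp.float32))
    _, root = lax.while_loop(cond_fun, body_fun, init)
    return root

def newton_sqrt_fwd(a):
    root = newton_sqrt(a)
    # Save only the converged root; no per-iteration state is stored.
    return root, root

def newton_sqrt_bwd(root, g):
    # From root**2 == a, d(root)/da = 1 / (2 * root).
    return (g / (2.0 * root),)

newton_sqrt.defvjp(newton_sqrt_fwd, newton_sqrt_bwd)

grad_root = jax.grad(newton_sqrt)(2.0)
```

In this sketch the forward pass runs only as many iterations as needed and the backward pass stores just the final value, at the cost of having to derive the backward rule manually.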

dumanah commented 2 years ago

> You could do this with lax.cond inside lax.scan -- just apply the identity function if the loop should terminate early.
>
> The downside is that you will pay the price of memory allocation for all N steps. This is a limitation of XLA's memory model: all memory allocation must use statically known shapes.

Hi, I am trying to get an early break within lax.scan, since lax.while_loop takes longer to compile and runs slower. However, I did not understand exactly how to use lax.cond to achieve that. Could you elaborate, e.g. with a small code example?

epignatelli commented 3 months ago

Hey @mattjj! Late to the party, but I am facing the same limitations. Is there any plan to implement this?