Open salayatana66 opened 3 years ago
You could do this with `lax.cond` inside `lax.scan`: just apply the identity function if the loop should terminate early.
The downside is that you will pay the price of memory allocation for all N steps. This is a limitation of XLA's memory model: all memory allocation must use statically known shapes.
This came up in one of our chat rooms yesterday, and it was also observed that because of how we currently implement `vmap`-of-`cond` (namely, we always turn it into a `select`), using `cond` for this might lead to inefficient batching.
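To make the batching concern concrete, here is a minimal sketch. The `step` function and its "expensive" branch are made up for illustration, and the exact primitive name printed in the jaxpr (for example `select_n`) may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(x, done):
    # Identity branch when `done` is set, otherwise a stand-in for an expensive update.
    return lax.cond(done, lambda v: v, lambda v: jnp.sin(v) * 2.0, x)

# Unbatched: the jaxpr keeps a real `cond`, so only one branch runs.
print(jax.make_jaxpr(step)(1.0, False))

# Batched predicate: both branches are computed for every element and the
# results are combined with a select, so the "skipped" work is still done.
batched = jax.vmap(step)
xs = jnp.ones(3)
flags = jnp.array([True, False, True])
print(jax.make_jaxpr(batched)(xs, flags))
print(batched(xs, flags))
```

If the predicate is not batched (for example `jax.vmap(step, in_axes=(0, None))`), the `cond` should be preserved, so this only hurts when the stopping condition differs per batch element.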
I think adding an early-stopping version of `lax.scan` is a good idea. It's not trivial, though. Notes to self: we'd probably need to add a jaxpr for the early-stopping function, and also add 'start' and 'stop' arguments to the scan for the linearized case.
As @shoyer observed, we'd always pay the memory cost.
Until we add this (or figure out a better plan), you might be able to work around the issue using `jax.custom_vjp` around your loop.
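One way the `jax.custom_vjp` workaround could look is sketched below. This is only an illustration under assumptions: the loop is a Newton iteration for a square root (a hypothetical `newton_sqrt`), and the backward rule comes from implicit differentiation of the converged result, which only applies when the loop solves an equation like this one.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.custom_vjp
def newton_sqrt(x):
    # Early-stopping loop: iterate until converged, with at most 50 steps.
    def cond_fun(state):
        i, y, y_prev = state
        return (i < 50) & (jnp.abs(y - y_prev) > 1e-6)

    def body_fun(state):
        i, y, _ = state
        return i + 1, 0.5 * (y + x / y), y

    init = (jnp.int32(0), x, jnp.full_like(x, jnp.inf))
    _, y, _ = lax.while_loop(cond_fun, body_fun, init)
    return y

def newton_sqrt_fwd(x):
    y = newton_sqrt(x)
    return y, y  # save only the converged value as the residual

def newton_sqrt_bwd(y, g):
    # Implicit differentiation of y * y = x gives dy/dx = 1 / (2 * y).
    return (g / (2.0 * y),)

newton_sqrt.defvjp(newton_sqrt_fwd, newton_sqrt_bwd)

x = jnp.float32(4.0)
print(newton_sqrt(x), jax.grad(newton_sqrt)(x))
```

Because the backward pass only needs the final value, nothing is stored per iteration. For loops that are not solving a fixed-point or root-finding problem, the backward rule would have to be written differently.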
Hi, I am trying to break out of `lax.scan` early, since `lax.while_loop` takes longer to compile and runs slower. However, I did not understand exactly how to use `lax.cond` to achieve that. Could you elaborate, e.g. with a small code example?
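A minimal sketch of the `lax.cond`-inside-`lax.scan` pattern might look like the following. The update rule (halving `x`), the tolerance, and the names `run`, `step`, and `done` are all made up for illustration; the idea is to carry a boolean flag and take the identity branch once it is set.

```python
import jax
import jax.numpy as jnp
from jax import lax

def run(x0, num_steps=20, tol=1e-3):
    def step(carry, _):
        x, done = carry
        # Once `done` is set, the identity branch runs instead of the update.
        x_new = lax.cond(done, lambda v: v, lambda v: 0.5 * v, x)
        done_new = done | (jnp.abs(x_new) < tol)
        return (x_new, done_new), x_new

    init = (x0, jnp.array(False))
    (x_final, _), history = lax.scan(step, init, None, length=num_steps)
    return x_final, history

x_final, history = run(jnp.float32(1.0))
print(x_final, history.shape)

# Reverse-mode differentiation works because the trip count is fixed.
print(jax.grad(lambda x: run(x)[0])(jnp.float32(1.0)))
```

The scan still executes all `num_steps` iterations and allocates memory for all of them, as noted above; the `cond` only skips the update inside each remaining step.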
Hey @mattjj! Late to the party, but I am facing the same limitations. Is there any plan to implement this?
I am considering a loop in which each operation is expensive. One can bound the worst-case number of operations by a fixed N, but there is a stopping condition, and on average the loop terminates after K operations with K << N.
With `lax.while_loop` it is relatively easy to implement this efficiently, but unfortunately one loses reverse-mode differentiation. Would it be possible to have a version of `lax.scan` that supports such early stopping?
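For reference, a small sketch of the trade-off described here, with a made-up loop body (halving `y` until it falls below a tolerance) and a hypothetical `bounded_loop` helper: the `lax.while_loop` version stops early and supports forward-mode differentiation, but asking for a reverse-mode gradient raises an error.

```python
import jax
import jax.numpy as jnp
from jax import lax

def bounded_loop(x, max_steps=100, tol=1e-3):
    def cond_fun(state):
        i, y = state
        # Stop at the worst-case bound or as soon as the condition is met.
        return (i < max_steps) & (jnp.abs(y) > tol)

    def body_fun(state):
        i, y = state
        return i + 1, 0.5 * y

    _, y = lax.while_loop(cond_fun, body_fun, (0, x))
    return y

# Forward mode (jvp) works with a data-dependent trip count.
primal, tangent = jax.jvp(bounded_loop, (1.0,), (1.0,))
print(primal, tangent)

# Reverse mode does not: JAX cannot transpose a loop of unknown length.
try:
    jax.grad(bounded_loop)(1.0)
except Exception as err:
    print("reverse mode failed:", type(err).__name__)
```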