This is an intentional requirement of `scan`. The constant length is what enables support of reverse-mode autodiff that still compiles via `jit`. Variable-length loops are available as `lax.while_loop` and `lax.fori_loop`. These will `jit` and support forward-mode autodiff.
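For example, a loop whose trip count depends on a runtime value can be forward-differentiated but not reverse-differentiated. A small sketch (the function here is made up):

```python
import jax
from jax import lax

def iterate(x, n):
    # apply x -> 0.5 * x + 1.0 a runtime-dependent number of times
    def cond(state):
        i, _ = state
        return i < n

    def body(state):
        i, v = state
        return i + 1, 0.5 * v + 1.0

    _, out = lax.while_loop(cond, body, (0, x))
    return out

print(jax.jit(iterate)(2.0, 5))      # jit works with a traced trip count
print(jax.jacfwd(iterate)(2.0, 5))   # forward-mode autodiff works
# jax.grad(iterate)(2.0, 5)          # reverse-mode raises an error
```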
Thanks for the reply.
Are there any plans on supporting "variable-length" scans up to a max number of iterations? I am tempted to ask due to this comment from @mattjj and this implementation. Perhaps now with omnistaging, the use of `cond` in the previously mentioned implementation would result in efficient code, based on this quote:

> In addition to improving JAX’s memory performance, omnistaging enables a host of other improvements and simplifications throughout JAX. For example, it allows `lax.cond` to accept thunks, so that `lax.cond(x > 0, lambda: f(y), lambda: g(z))` will only evaluate one of `f(y)` and `g(z)`.
> Are there any plans on supporting "variable-length" scans up to a max number of iterations?
We don't have this "off the shelf," but you could combine `scan` with `cond` (or `select`) as in #3850, to mask out final iterations up to the maximum.
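A rough sketch of that masking approach (the helper name is mine, not an existing API):

```python
import jax.numpy as jnp
from jax import lax

def masked_scan(step, init, max_steps, num_steps):
    """Run `step` for `max_steps` iterations, but stop updating
    the carry once the iteration index reaches `num_steps`."""
    def body(carry, i):
        new = step(carry)
        carry = lax.cond(i < num_steps,
                         lambda c: c[0],  # still active: take the new carry
                         lambda c: c[1],  # done: keep the old carry
                         (new, carry))
        return carry, None
    carry, _ = lax.scan(body, init, jnp.arange(max_steps))
    return carry
```

Since `max_steps` is static, this `jit`s and supports reverse-mode autodiff, while `num_steps` can be a traced runtime value.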
Regarding efficiency: with or without omnistaging, you could make sure a `cond` carries out the computations that you'd expect by passing it explicit operands rather than relying on closure capture, e.g. `lax.cond(x > 0, lambda yz: f(yz[0]), lambda yz: g(yz[1]), (y, z))`.
But indeed, with omnistaging, the simpler expression `lax.cond(x > 0, lambda: f(y), lambda: g(z))` will behave similarly.
Looking only at your original example: does using `lax.fori_loop` instead of `lax.scan` work? It seems to fit, since your scan takes no inputs and produces no output array, and since there's no reverse-mode autodiff taking place.
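For the Euler example, that might look like the following sketch (assuming a right-hand side `f` like the one in the issue):

```python
from jax import jit, lax

def f(x):
    return -x  # illustrative right-hand side

@jit
def integrate_euler(x0, dt, num_steps):
    # unlike scan's static `length`, fori_loop accepts a traced
    # trip count under jit
    return lax.fori_loop(0, num_steps, lambda i, x: x + dt * f(x), x0)
```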
Thank you @froystig and apologies for my delayed reply. My understanding of `jax` and automatic differentiation is limited, so, unfortunately, it takes me time to generate "minimal" examples that reflect the problems I am trying to solve.

My use case is more complicated than the minimal example I have listed above. The functions that are integrated depend on parameters which I am optimising, so reverse-mode autodiff is needed to get gradients efficiently.
However, it appears that `lax.cond` does not allow reverse-mode autodiff to run efficiently, i.e. to mask out final iterations up to the maximum. This is illustrated in the following example:
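In rough form, the experiment looks like this (a condensed sketch of what I ran, with stand-in dynamics):

```python
import jax
import jax.numpy as jnp
from jax import jit, lax

MAX_STEPS = 1000

def step(x, theta):
    return x + 0.01 * (-theta * x)  # stand-in for the parameterised dynamics

@jit
def masked_integrate(theta, num_steps):
    def body(carry, i):
        new = step(carry, theta)
        # freeze the carry once i >= num_steps
        return lax.cond(i < num_steps, lambda c: c[0], lambda c: c[1],
                        (new, carry)), None
    out, _ = lax.scan(body, 1.0, jnp.arange(MAX_STEPS))
    return out

value = masked_integrate(1.0, 500)
grad = jit(jax.grad(masked_integrate))(1.0, 500)  # the slow call in question
```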
Are these results expected? My system details are:
jax==0.2.1
jaxlib==0.1.55
MacOS 10.15.7
Python 3.7.7
@froystig should I close this?
Basically my question ended up being:

> Is it possible to efficiently do backpropagation on "variable-length" scans/loops?

If the answer is "not possible" or "not planning to support this," then there is no reason to keep this issue open.
> Is it possible to efficiently do backpropagation on "variable-length" scans/loops?
Reverse-mode AD with loops of dynamic trip count requires dynamic storage (corresponding to the loop iterations). That's at odds with our machine model under `jit`, in which all memory allocations are statically known in the compiled function, so we don't support it. We're avoiding introducing a transformation that works only in some execution contexts. If we always had dynamic memory, we'd simply do reverse-mode of loops using a stack.
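To illustrate the storage issue in plain Python (a conceptual sketch, not JAX internals):

```python
# Reverse-mode through a dynamic loop needs a stack of per-iteration
# values, whose size is unknown until runtime.
def loop_vjp(f, df, x, cond):
    stack = []
    while cond(x):           # dynamic trip count
        stack.append(x)      # storage grows with the iteration count
        x = f(x)
    def backward(g):
        while stack:
            g = g * df(stack.pop())  # chain rule, applied in reverse
        return g
    return x, backward

# d/dx of x -> x/8 (three halvings of 8.0) is 1/8
y, backward = loop_vjp(lambda x: 0.5 * x, lambda x: 0.5, 8.0, lambda x: x > 1.0)
print(y, backward(1.0))  # 1.0 0.125
```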
Thanks for sharing the more complete example. This looks like the correct experiment to try for a "masked scan." The resulting runtime increase is unfortunate. Note that the performance of loops and conditionals can vary quite a bit by device. We (and XLA) can keep this example in mind as we do performance work in the future.
The issue:

Doing `jit`ted `scan`s with variable lengths throws `jax.core.ConcretizationTypeError`, as evidenced in the following minimal example.

Minimal example:

Suppose we want to solve an initial value problem dx/dt = f(x) via Euler integration. An implementation using `scan` is given below.
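Here is a minimal sketch of that implementation (the right-hand side `f` is an illustrative stand-in):

```python
import jax.numpy as jnp
from jax import jit, lax

def f(x):
    return -x  # illustrative right-hand side

@jit
def integrate_euler(x0, dt, num_steps):
    def step(x, _):
        return x + dt * f(x), None
    # scan's `length` must be a concrete integer; under jit,
    # `num_steps` is a tracer
    x_final, _ = lax.scan(step, x0, None, length=num_steps)
    return x_final

integrate_euler(jnp.array(1.0), 0.01, 100)
```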
Without the `jit` decorator above `integrate_euler`, the above example works fine. However, when including `jit`, it throws `jax.core.ConcretizationTypeError`.