jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Support for variable length in jitted scans? #4473

Closed nrontsis closed 3 years ago

nrontsis commented 3 years ago

The issue:

Doing jitted scans with variable lengths throws jax.core.ConcretizationTypeError, as shown in the following minimal example.

Minimal example:

Suppose we want to solve an initial value problem dx/dt = f(x) via Euler integration. An implementation using scan is given below:

from functools import partial
from jax import jit
from jax.lax import scan
import jax.numpy as np

@partial(jit, static_argnums=(0,))
def integrate_euler(f, initial_state, time, dt):
    propagated_state, _ = scan(
        f=lambda state, _: (euler_integration_step(f, state, dt), None),
        init=initial_state,
        length=np.divide(time, dt).astype(int),
        xs=None,
    )
    return propagated_state

def euler_integration_step(f, state, dt):
    return state + dt*f(state)

# Example: solve dx/dt = 9x/10, for t=1.0, starting with x0 = [1, 2], and using a timestep of 0.01.
integrate_euler(
    lambda x: 0.9*x,
    initial_state=np.array([1., 2.]),
    time=1.0,
    dt=1e-2
)

Without the jit decorator on integrate_euler, the example above works fine. However, with jit included, it throws:

jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in `int` Try using `x.astype(int)` instead.).
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray(int64[]):JaxprTrace(level=-1/1)>
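
(A sketch of the workaround that the error message points to, for completeness: marking time and dt as static with static_argnums makes the length a concrete Python int at trace time, though jit then recompiles whenever those values change. This is not code from the thread, just an illustration.)

```python
from functools import partial
from jax import jit
from jax.lax import scan
import jax.numpy as np

@partial(jit, static_argnums=(0, 2, 3))
def integrate_euler_static(f, initial_state, time, dt):
    n_steps = int(round(time / dt))  # plain Python arithmetic on static arguments
    propagated_state, _ = scan(
        f=lambda state, _: (state + dt * f(state), None),
        init=initial_state,
        length=n_steps,
        xs=None,
    )
    return propagated_state

integrate_euler_static(lambda x: 0.9 * x, np.array([1., 2.]), 1.0, 1e-2)
```
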
froystig commented 3 years ago

This is an intentional requirement of scan. The constant length is what enables support of reverse-mode autodiff that still compiles via jit. Variable-length loops are available as lax.while_loop and lax.fori_loop. These will jit and support forward-mode autodiff.
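
For the Euler example above, a minimal sketch of the while_loop route might look like the following (the dynamics function is inlined for simplicity; this is only a hypothetical illustration, not code from the thread):

```python
from jax import jit
from jax.lax import while_loop
import jax.numpy as np

@jit
def integrate_euler_while(initial_state, time, dt):
    # Dynamics inlined here for simplicity; a hypothetical stand-in for f.
    f = lambda x: 0.9 * x

    def cond_fun(carry):
        _, t = carry
        return t < time

    def body_fun(carry):
        state, t = carry
        return state + dt * f(state), t + dt

    # The trip count depends on traced values (time, dt), which while_loop allows;
    # the trade-off is that only forward-mode autodiff is supported through it.
    final_state, _ = while_loop(cond_fun, body_fun, (initial_state, 0.0))
    return final_state

integrate_euler_while(np.array([1., 2.]), 1.0, 1e-2)
```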

nrontsis commented 3 years ago

Thanks for the reply.

Are there any plans to support "variable-length" scans up to a maximum number of iterations? I am tempted to ask because of this comment from @mattjj and this implementation. Perhaps now, with omnistaging, the use of cond in the previously mentioned implementation would result in efficient code, based on this quote:

In addition to improving JAX’s memory performance, omnistaging enables a host of other improvements and simplifications throughout JAX. For example, it allows lax.cond to accept thunks, so that lax.cond(x > 0, lambda: f(y), lambda: g(z)) will only evaluate one of f(y) and g(z)

froystig commented 3 years ago

Are there any plans to support "variable-length" scans up to a maximum number of iterations?

We don't have this "off the shelf," but you could combine scan with cond (or select) as in #3850, to mask out final iterations up to the maximum.
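
For instance, a rough sketch of that masking pattern (hypothetical names, using np.where for the element selection rather than lax.select, and not taken from #3850 itself):

```python
from functools import partial
from jax import jit
from jax.lax import scan
import jax.numpy as np

MAX_STEPS = 1000  # hypothetical static upper bound on the number of iterations

@partial(jit, static_argnums=(0,))
def integrate_euler_masked(f, initial_state, time, dt):
    n_active = np.floor(time / dt).astype(np.int32)  # traced; may differ per call

    def step(state, i):
        new_state = state + dt * f(state)
        # While i < n_active take an Euler step; afterwards carry the state unchanged.
        return np.where(i < n_active, new_state, state), None

    final_state, _ = scan(step, initial_state, np.arange(MAX_STEPS))
    return final_state

integrate_euler_masked(lambda x: 0.9 * x, np.array([1., 2.]), 1.0, 1e-2)
```

Because the scan length is the static MAX_STEPS, reverse-mode autodiff is available through it, but the backward pass still runs over all MAX_STEPS iterations.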

Regarding efficiency: with or without omnistaging, you could make sure a cond carries out the computations that you'd expect by passing it explicit operands rather than relying on closure capture, e.g.:

lax.cond(x > 0, lambda yz: f(yz[0]), lambda yz: g(yz[1]), (y, z))

But indeed, with omnistaging, the simpler expression lax.cond(x > 0, lambda: f(y), lambda: g(z)) will behave similarly.
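
A self-contained illustration of the two forms (hypothetical f, g, and scalar inputs; just a sketch):

```python
from jax import jit, lax
import jax.numpy as np

# Hypothetical branch functions.
f = lambda y: y + 1.0
g = lambda z: z * 2.0

@jit
def branch_explicit(x, y, z):
    # Operands passed explicitly; each branch receives the (y, z) tuple.
    return lax.cond(x > 0, lambda yz: f(yz[0]), lambda yz: g(yz[1]), (y, z))

@jit
def branch_closure(x, y, z):
    # With omnistaging, zero-argument branches that close over y and z behave similarly.
    return lax.cond(x > 0, lambda: f(y), lambda: g(z))

branch_explicit(1.0, 2.0, 3.0), branch_closure(1.0, 2.0, 3.0)
```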

Looking only at your original example: does using lax.fori_loop instead of lax.scan work? It seems to fit, since your scan takes no inputs, and produces no output array, and since there's no reverse-mode autodiff taking place.
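
For example, a sketch of the original integrator written with lax.fori_loop (hypothetical, not code from the thread):

```python
from functools import partial
from jax import jit
from jax.lax import fori_loop
import jax.numpy as np

@partial(jit, static_argnums=(0,))
def integrate_euler_fori(f, initial_state, time, dt):
    # A traced trip count is fine for fori_loop (it lowers to a while_loop),
    # but reverse-mode autodiff through the loop is then unavailable.
    n_steps = np.floor(time / dt).astype(np.int32)
    return fori_loop(0, n_steps, lambda i, state: state + dt * f(state), initial_state)

integrate_euler_fori(lambda x: 0.9 * x, np.array([1., 2.]), 1.0, 1e-2)
```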

nrontsis commented 3 years ago

Thank you @froystig and apologies for my delayed reply. My understanding of jax and automatic differentiation is limited, so, unfortunately, it takes me time to generate "minimal" examples that reflect the problems I am trying to solve.

My use case is more complicated than the minimal example I listed above. The functions being integrated depend on parameters that I am optimising, so reverse-mode autodiff is needed to get gradients efficiently.

However, it appears that lax.cond does not let reverse-mode autodiff run efficiently, i.e. it does not cheaply mask out the final iterations up to the maximum. This is illustrated in the following example:

Minimal example:

```python
from typing import Tuple
from jax import jit, jacrev, grad, jacfwd
from jax.lax import scan, cond
import jax.numpy as np

INTEGRATION_TIMEDELTA = 1e-1
MAX_ITER = 100

def integrate_with_euler(
    compute_derivatives: callable,
    initial_state: np.ndarray,
    dt: float,
    parameters: np.ndarray
) -> np.ndarray:
    print("compiling")
    return scan(
        f=lambda state_time, _: (scan_function(compute_derivatives, state_time, dt, parameters), None),
        init=(initial_state, 0.0),
        xs=None,
        length=MAX_ITER,
    )[0][0]

def step(f: callable, state: np.ndarray, time: float, stop_time: float, params: np.ndarray) -> Tuple[np.ndarray, float]:
    next_time = np.minimum(time + INTEGRATION_TIMEDELTA, stop_time)
    next_state = state + (next_time - time) * f(state, time, params)
    return next_state, next_time

def no_operation(f: callable, state: np.ndarray, time: float, stop_time: float, params: np.ndarray) -> Tuple[np.ndarray, float]:
    return state, time

def scan_function(f: callable, state_time: Tuple[np.ndarray, float], stop_time: float, params: np.ndarray) -> Tuple[np.ndarray, float]:
    state, time = state_time
    next_state, next_time = cond(
        time < stop_time,
        lambda inputs: step(f, *inputs),
        lambda inputs: no_operation(f, *inputs),
        operand=(state, time, stop_time, params),
    )
    return next_state, next_time

integrate = jit(integrate_with_euler, static_argnums=(0,))
gradient = jit(grad(lambda f, x, dt, args: np.sum(integrate(f, x, dt, args)), argnums=-1), static_argnums=(0,))
```
Benchmark code: copy-paste the following into `ipython`, after having run the minimal example above.

```python
import numpy as onp

A = -1e-2*onp.eye(1000)
def example_function(x: np.ndarray, t: float, param: np.ndarray) -> np.ndarray:
    return A@x + param

dt1 = 1.0
dt2 = 10.0
x0 = onp.random.randn(1000)
parameters = onp.random.randn(1000)

# Compile
integrate(example_function, x0, dt1, parameters).block_until_ready()
gradient(example_function, x0, dt1, parameters).block_until_ready()

# Profile
print("Profiling integration %f vs %f" % (dt1, dt2))
%timeit integrate(example_function, x0, dt1, parameters).block_until_ready()
%timeit integrate(example_function, x0, dt2, parameters).block_until_ready()
print("Profiling gradient %f vs %f" % (dt1, dt2))
%timeit gradient(example_function, x0, dt1, parameters).block_until_ready()
%timeit gradient(example_function, x0, dt2, parameters).block_until_ready()
```
Benchmark results:

```
Profiling integration 1.000000 vs 10.000000
3.5 ms ± 586 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
14.8 ms ± 970 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Profiling gradient 1.000000 vs 10.000000
313 ms ± 2.84 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
468 ms ± 15.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

Are these results expected? My system details are:

jax==0.2.1
jaxlib==0.1.55
MacOS 10.15.7
Python 3.7.7
nrontsis commented 3 years ago

@froystig should I close this?

Basically my question ended up being:

Is it possible to efficiently do backpropagation on "variable length" scans/loops?

and if the answer is "not possible" or "not planning to support this" then there is no reason to have this issue open.

froystig commented 3 years ago

Is it possible to efficiently do backpropagation on "variable length" scans/loops?

Reverse-mode AD with loops of dynamic trip count requires dynamic storage (corresponding to the loop iterations). That's at odds with our machine model under jit, in which all memory allocations are statically known in the compiled function, so we don't support it. We're avoiding introducing a transformation that works only in some execution contexts. If we always had dynamic memory, we'd simply do reverse-mode of loops using a stack.
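
One way to see the storage constraint concretely (an illustrative sketch, not from the thread): the VJP of a scan stacks per-iteration residuals along a leading axis whose size equals the trip count, so that size has to be static when the function is staged out under jit.

```python
import jax
import jax.numpy as np
from jax.lax import scan

def rollout(x0, n):
    def body(x, _):
        return np.sin(x), None  # a residual (e.g. cos(x)) is saved each step for the backward pass
    xf, _ = scan(body, x0, None, length=n)
    return xf

# The printed jaxpr contains float arrays with a leading axis of size n (= 5 here),
# which is why n cannot be a traced value once reverse-mode autodiff is involved.
print(jax.make_jaxpr(lambda x: jax.grad(rollout)(x, 5))(1.0))
```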

Thanks for sharing the more complete example. This looks like the correct experiment to try for a "masked scan." The resulting runtime increase is unfortunate. Note that the performance of loops and conditionals can vary quite a bit by device. We (and XLA) can keep this example in mind as we do performance work in the future.