jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Implement an option for lax.while_loop to specify the maximum number of iterations, to allow reverse differentiation #2469

Open gnecula opened 4 years ago

gnecula commented 4 years ago

Several forms of loops in JAX support reverse-mode AD: scan, and fori_loop with constant bounds (which is syntactic sugar for scan). lax.while_loop, whose trip count is data-dependent, supports only forward-mode AD. I think it could be useful to add another piece of syntactic sugar for bounded loops: a max_iterations parameter on while_loop.
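To make the distinction concrete, here is a small sketch (the function names are illustrative only). A fori_loop with Python-int bounds, which JAX lowers to scan, can be differentiated with jax.grad. The same kind of computation written with lax.while_loop is rejected in reverse mode. The sketch assumes JAX reports this as a ValueError; the exact error type and message may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax


def square_three_times_fori(x):
    # Static (Python int) bounds: fori_loop is lowered to scan,
    # so reverse-mode AD is supported.
    return lax.fori_loop(0, 3, lambda i, v: v * v, x)


def square_until_large(x):
    # Data-dependent trip count: lax.while_loop supports forward-mode AD only.
    return lax.while_loop(lambda v: v < 100.0, lambda v: v * v, x)


x = jnp.float32(1.5)
g_fori = jax.grad(square_three_times_fori)(x)  # d/dx x**8 = 8 * x**7
print(g_fori)

try:
    jax.grad(square_until_large)(x)
except ValueError as err:
    # Assumed error type: JAX refuses to transpose a while_loop.
    print("while_loop grad failed:", type(err).__name__)
```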

tavin commented 2 years ago

I can't say why it works, but I'm using while_loop with jacrev. Nevertheless, I would like to see this enhancement: handling the iteration limit by hand, in the loop state and in the cond/body functions, is a bit of a pain.
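The manual bookkeeping mentioned above might look roughly like the following sketch, where bounded_while_manual is a hypothetical helper, not a JAX API. It caps the trip count by carrying a counter next to the state. Because it is still a lax.while_loop, the JAX control-flow documentation lists it as forward-mode differentiable only, so the cap by itself does not unlock grad or jacrev.

```python
import jax.numpy as jnp
from jax import lax


def bounded_while_manual(cond_fun, body_fun, init_val, max_iters):
    # Hypothetical helper: carry an iteration counter alongside the state.
    def cond(carry):
        i, val = carry
        # Stop when either the cap is reached or the user condition fails.
        return jnp.logical_and(i < max_iters, cond_fun(val))

    def body(carry):
        i, val = carry
        return i + 1, body_fun(val)

    n_iters, val = lax.while_loop(cond, body, (0, init_val))
    return val, n_iters


# Doubling from 1 would need 7 steps to pass 100; the cap of 5 stops it at 32.
val, n = bounded_while_manual(lambda x: x < 100, lambda x: x * 2, 1, 5)
print(val, n)
```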

carlosgmartin commented 8 months ago

I'd like to see this as well. Here's a possible implementation:

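```python
from jax import lax


def while_loop(cond_fun, body_fun, init_val, max_iters=None):
    if max_iters is None:
        # Unbounded: plain lax.while_loop (no reverse-mode AD).
        return lax.while_loop(cond_fun, body_fun, init_val)
    else:
        # Bounded: always run max_iters scan steps; once cond_fun is False,
        # the remaining steps pass the state through unchanged.
        def f(val, _):
            val = lax.cond(cond_fun(val), body_fun, lambda v: v, val)
            return val, None

        val, _ = lax.scan(f, init_val, None, length=max_iters)
        return val


def test_while_loop():
    def cond_fun(x):
        return x < 100

    def body_fun(x):
        return x * 2

    init_val = 1

    out_1 = lax.while_loop(cond_fun, body_fun, init_val)
    out_2 = while_loop(cond_fun, body_fun, init_val, 20)
    print(out_1, out_2)  # 128 128: 1 is doubled 7 times in both cases


test_while_loop()
```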
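As a quick check of the motivating use case, the sketch below repeats the scan + cond construction from the snippet above in condensed form and takes jax.grad through it, which lax.while_loop itself rejects. It assumes a recent JAX release; bounded_while_loop and f are illustrative names only.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while_loop(cond_fun, body_fun, init_val, max_iters):
    # Condensed scan + cond pattern: exactly max_iters scan steps, and the
    # carry is passed through unchanged once cond_fun becomes False.
    def step(val, _):
        return lax.cond(cond_fun(val), body_fun, lambda v: v, val), None

    val, _ = lax.scan(step, init_val, None, length=max_iters)
    return val


def f(x):
    # Square until the value reaches 100, for at most 10 iterations.
    return bounded_while_loop(lambda v: v < 100.0, lambda v: v * v, x, 10)


x = jnp.float32(1.5)
print(f(x))            # 1.5 ** 16 (four squarings), up to float32 rounding
print(jax.grad(f)(x))  # 16 * 1.5 ** 15, up to float32 rounding
```

The trade-offs of this design: every call runs all max_iters scan steps, under vmap with a batched predicate lax.cond turns into a select that evaluates both branches, and reverse mode keeps residuals for all max_iters steps. If cond_fun is still True after max_iters steps, the result also differs from the unbounded loop.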

NeilGirdhar commented 8 months ago

JaxOpt's implementation has been ironed out: https://github.com/google/jaxopt/blob/main/jaxopt/_src/loop.py