gnecula opened this issue 4 years ago.
I can't say why it works, but I'm using `while_loop` with `jacrev`. Nevertheless, I would like to see this enhancement, as it's a bit of a pain to handle the iteration bound manually in the loop state and the cond/body functions.
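For context, here is a minimal sketch of the manual bookkeeping the comment alludes to, assuming "this" means capping the number of iterations. An iteration counter is threaded through the loop state by hand and checked in the condition. `bounded_while_manual` is a hypothetical name used only for illustration. Because it is still a `lax.while_loop`, this workaround does not change how the loop interacts with reverse-mode AD.

```python
import jax.numpy as jnp
from jax import lax


def bounded_while_manual(cond_fun, body_fun, init_val, max_iters):
    """Hypothetical helper: cap lax.while_loop by carrying a counter in the state."""
    def cond(state):
        i, val = state
        # Stop when the user's condition fails or the iteration budget is used up.
        return jnp.logical_and(i < max_iters, cond_fun(val))

    def body(state):
        i, val = state
        # Advance the counter alongside the user's state.
        return i + 1, body_fun(val)

    _, final_val = lax.while_loop(cond, body, (0, init_val))
    return final_val


# Doubling 1 while it is below 100 takes 7 steps and ends at 128.
print(bounded_while_manual(lambda x: x < 100, lambda x: x * 2, 1, 20))  # 128
# With a budget of 3 iterations the loop stops early at 8.
print(bounded_while_manual(lambda x: x < 100, lambda x: x * 2, 1, 3))   # 8
```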
I'd like to see this as well. Here's a possible implementation:
from jax import lax


def while_loop(cond_fun, body_fun, init_val, max_iters=None):
    if max_iters is None:
        # Unbounded: plain lax.while_loop (not reverse-mode differentiable).
        return lax.while_loop(cond_fun, body_fun, init_val)
    else:
        # Bounded: run exactly max_iters scan steps. Once cond_fun is False,
        # every remaining step is a no-op that passes the value through.
        def f(val, _):
            val = lax.cond(cond_fun(val), body_fun, lambda v: v, val)
            return val, None

        val, _ = lax.scan(f, init_val, None, length=max_iters)
        return val


def test_while_loop():
    def cond_fun(x):
        return x < 100

    def body_fun(x):
        return x * 2

    init_val = 1
    out_1 = lax.while_loop(cond_fun, body_fun, init_val)
    out_2 = while_loop(cond_fun, body_fun, init_val, 20)
    print(out_1, out_2)  # 128 128
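The benefit of building the bounded loop on `lax.scan` is that reverse-mode AD goes through it. A minimal check, assuming a JAX version where `lax.cond(pred, true_fun, false_fun, *operands)` is available: the helper is repeated in condensed form (with `max_iters` required) so the snippet runs on its own, and `jax.grad` is applied to it. Starting from x = 1.0 the loop doubles seven times to reach 128, so the derivative should be 2**7 = 128.

```python
import jax
from jax import lax


def while_loop(cond_fun, body_fun, init_val, max_iters):
    # Condensed copy of the scan-based helper: after cond_fun turns False,
    # the remaining scan steps leave the value unchanged.
    def step(val, _):
        val = lax.cond(cond_fun(val), body_fun, lambda v: v, val)
        return val, None

    val, _ = lax.scan(step, init_val, None, length=max_iters)
    return val


def f(x):
    # Double x while it is below 100, with at most 20 iterations.
    return while_loop(lambda v: v < 100.0, lambda v: v * 2.0, x, max_iters=20)


print(f(1.0))            # 128.0
print(jax.grad(f)(1.0))  # 128.0
```

Swapping `while_loop` for `lax.while_loop` inside `f` makes `jax.grad` fail, since `lax.while_loop` does not support reverse-mode differentiation.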
JaxOpt's implementation has been ironed out: https://github.com/google/jaxopt/blob/main/jaxopt/_src/loop.py
Several forms of loops in JAX support reverse AD: `scan`, and `fori_loop` with constant bounds, which is syntactic sugar for `scan`. I think it could be useful to have another syntactic sugar for bounded loops by adding a parameter `max_iterations` to `while_loop`.
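A small illustration of the existing case mentioned above: when the bounds of `lax.fori_loop` are Python ints, the trip count is known at trace time, the loop is implemented with `scan`, and `jax.grad` works through it. Squaring three times gives x**8, whose derivative at x = 2 is 8 * 2**7 = 1024. `repeated_square` is just an illustrative name, and the final `make_jaxpr` line is only there to inspect which primitive gets emitted.

```python
import jax
from jax import lax


def repeated_square(x):
    # Constant Python-int bounds: the trip count is static, so fori_loop is
    # lowered to scan and reverse-mode differentiation is supported.
    return lax.fori_loop(0, 3, lambda i, v: v * v, x)


print(repeated_square(2.0))            # 256.0
print(jax.grad(repeated_square)(2.0))  # 1024.0

# Inspect the traced program; with static bounds it should contain a scan.
print("scan" in str(jax.make_jaxpr(repeated_square)(2.0)))
```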