Open carlosgmartin opened 2 years ago
Interesting idea - if I understand your suggestion correctly, I think the implementations would look like this:
from jax import lax

def iterate(n, f, x):
    return lax.scan(lambda x, _: (f(x), x), x, None, length=n)[0]

def orbit(n, f, x):
    return lax.scan(lambda x, _: (f(x), x), x, None, length=n + 1)[1]
Is that what you have in mind?
For iterate we may want lax.scan(lambda x, _: (f(x), None), x, None, length=n)[0], i.e. don't have an extensive output.
@jakevdp @mattjj Looks right:
from jax.numpy import array
from jax.lax import scan

def iterate_1(n, f, x):
    for _ in range(n):
        x = f(x)
    return x

def iterate_2(n, f, x):
    return scan(lambda x, _: (f(x), x), x, None, length=n)[0]

def iterate_3(n, f, x):
    return scan(lambda x, _: (f(x), None), x, None, length=n)[0]

def orbit_1(n, f, x):
    xs = [x]
    for _ in range(n):
        x = f(x)
        xs.append(x)
    return array(xs)

def orbit_2(n, f, x):
    return scan(lambda x, _: (f(x), x), x, None, length=n + 1)[1]

n = 10
f = lambda x: 2 * x + 1
x = 3
print(iterate_1(n, f, x)) # 4095
print(iterate_2(n, f, x)) # 4095
print(iterate_3(n, f, x)) # 4095
print(orbit_1(n, f, x)) # [ 3 7 15 31 63 127 255 511 1023 2047 4095]
print(orbit_2(n, f, x)) # [ 3 7 15 31 63 127 255 511 1023 2047 4095]
Out of curiosity, can this function be implemented in terms of lax primitives?
from jax.numpy import array

# orbit_while : ∀a. (a → 𝔹) → (a → a) → (a → [a])
def orbit_while(p, f, x):
    xs = []
    while p(x):
        xs.append(x)
        x = f(x)
    return array(xs) # contains only elements satisfying p

p = lambda x: x < 2047
f = lambda x: 2 * x + 1
x = 3
print(orbit_while(p, f, x)) # [ 3 7 15 31 63 127 255 511 1023]
A further question for discussion: are these new APIs necessary if they can be implemented via a single call to scan? One might argue that iterate and orbit already exist in JAX; they're just called scan. What do you think?
@jakevdp I think it's convenient to have these useful helper functions in the library to make it easier for users to adopt JAX, and to save them the trouble of figuring out how to implement them in terms of scan. After all, I don't think it'd be good to take away all existing control flow functions that can be implemented in terms of scan.
Sure, but adding new APIs does not come without maintenance costs. It's true that while_loop and fori_loop can be implemented in terms of scan, but their implementations are far more involved than a single line.
As this is the first time I'm aware of receiving such a request, I would lean toward not including them in the package API, but I'm happy to change my mind if there are compelling reasons to do so.
I'd much rather have a bounded while loop.
@jakevdp I think this works as a single-line implementation of fori_loop:
from jax.lax import fori_loop, scan

def fori_loop_scan(a, b, f, x):
    return scan(lambda ix, _: ((ix[0] + 1, f(*ix)), None), (a, x), None, b - a)[0][1]

a = 2
b = 8
f = lambda i, x: 3 + 2 * x + 4 * i + i * x
x = 10
y1 = fori_loop(a, b, f, x)
y2 = fori_loop_scan(a, b, f, x)
print(y1) # 827986
print(y2) # 827986
Out of curiosity, what's the implementation of while_loop in terms of scan?
Also, is it possible to implement orbit_while in terms of lax primitives?
@soraros What do you mean?
while_loop doesn't lower to scan, actually, since scan requires a static number of iterations.
And orbit_while is not currently possible to express in JIT-compatible JAX, because it returns an array of dynamic length.
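One way around the dynamic-length restriction is to fix an upper bound on the orbit length and return a padded array together with a count of valid entries. The following is only a sketch under that assumption: `orbit_while_padded`, `max_len`, and `fill` are hypothetical names and not part of JAX, and an orbit longer than `max_len` is silently truncated.

```python
import jax.numpy as jnp
from jax import lax


def orbit_while_padded(p, f, x, max_len, fill=0):
    """Return (xs, count). xs has static length max_len; only xs[:count] is valid."""

    def step(carry, _):
        x, alive = carry
        # Once p(x) fails, `alive` stays False for all remaining steps.
        alive = jnp.logical_and(alive, p(x))
        out = jnp.where(alive, x, fill)      # padded slots hold `fill`
        x_next = jnp.where(alive, f(x), x)   # freeze x after the predicate fails
        return (x_next, alive), (out, alive)

    init = (x, jnp.asarray(True))
    _, (xs, mask) = lax.scan(step, init, None, length=max_len)
    return xs, mask.sum()


p = lambda x: x < 2047
f = lambda x: 2 * x + 1
xs, count = orbit_while_padded(p, f, 3, max_len=12)
print(xs.tolist())  # [3, 7, 15, 31, 63, 127, 255, 511, 1023, 0, 0, 0]
print(int(count))   # 9
```

The output shape depends only on `max_len`, which is what makes this variant traceable; slicing `xs[:count]` to recover the exact orbit still has to happen outside of jit.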
@jakevdp That's what I suspected. Thanks.
@soraros By a bounded while loop, do you mean something like this?
import jax

def bounded_while_loop(p, f, x, n):
    def g(i, x):
        return jax.lax.cond(p(x), f, lambda x: x, x)
    return jax.lax.fori_loop(0, n, g, x)

def p(x):
    return x < 10

def f(x):
    return x + 1

x = 0
print(jax.lax.while_loop(p, f, x)) # 10
print(bounded_while_loop(p, f, x, 100)) # 10
print(bounded_while_loop(p, f, x, 5)) # 5
@jakevdp What do you think of letting xs=None by default in scan? That pattern seems to occur often.
I think that could be an improvement – I'd want to hear opinions from other folks on the team.
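To make the proposal concrete, here is a rough sketch of how call sites might look if xs were optional. `scan_n` is a hypothetical wrapper written purely for illustration; it is not an existing JAX function and says nothing about how the real signature would change.

```python
from jax import lax


def scan_n(f, init, xs=None, length=None):
    """Hypothetical wrapper: same as lax.scan, but xs may be omitted."""
    return lax.scan(f, init, xs, length=length)


def iterate(n, f, x):
    # The explicit None placeholder for xs is no longer needed at the call site.
    return scan_n(lambda x, _: (f(x), None), x, length=n)[0]


print(iterate(10, lambda x: 2 * x + 1, 3))  # 4095
```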
I'd also like to suggest the following functions:
from typing import Callable, Optional

import jax.numpy as jnp
from jax import lax
from jax.tree_util import tree_map

def foldl(f: Callable, h, xs, length: Optional[int] = None):
    '''
    http://zvon.org/other/haskell/Outputprelude/foldl_f.html
    (a → b → a) → a → [b] → a
    Arguments:
    `f`: Function.
    `h`: Initial value.
    `xs`: Inputs.
    Returns:
    Final value.
    '''
    def g(h, x):
        return f(h, x), None
    h, _ = lax.scan(g, h, xs, length)
    return h

def scanl(f: Callable, h, xs, length: Optional[int] = None):
    '''
    http://zvon.org/other/haskell/Outputprelude/scanl_f.html
    (a → b → a) → a → [b] → [a]
    Arguments:
    `f`: Function.
    `h`: Initial value.
    `xs`: Inputs.
    Returns:
    Intermediate values.
    '''
    def g(h, x):
        return f(h, x), h
    h, hs = lax.scan(g, h, xs, length)
    # hs holds the values before each step; append the final value.
    return tree_map(lambda h, hs: jnp.concatenate((hs, h[None])), h, hs)
These are very general and useful abstractions for processing sequences.
I think it's a good idea to add common control-flow constructs like these to the standard library. This saves users the trouble of having to figure out how to implement them in terms of scan, which has a more complex interface. (The latter can be especially inconvenient for new users not accustomed to the purely-functional approach.)
Add the following control flow operators for iterated functions:
iterate and fori_loop are mutually definable, but the former has a simpler signature and semantics that is more common in my experience, and is arguably more "basic" or "fundamental".
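The mutual definability can be sketched as follows. This is an illustration only: `iterate_via_fori` and `fori_via_iterate` are hypothetical names, and the second definition assumes that a and b are concrete Python integers, since scan needs a static length.

```python
from jax import lax


def iterate_via_fori(n, f, x):
    # fori_loop gives iterate by ignoring the loop index.
    return lax.fori_loop(0, n, lambda i, x: f(x), x)


def iterate(n, f, x):
    return lax.scan(lambda x, _: (f(x), None), x, None, length=n)[0]


def fori_via_iterate(a, b, f, x):
    # iterate gives fori_loop by threading the index through the state as (i, x).
    step = lambda ix: (ix[0] + 1, f(ix[0], ix[1]))
    return iterate(b - a, step, (a, x))[1]


print(iterate_via_fori(10, lambda x: 2 * x + 1, 3))  # 4095
print(fori_via_iterate(2, 8, lambda i, x: 3 + 2 * x + 4 * i + i * x, 10))  # 827986
```

The extra tuple packing in fori_via_iterate is the price of the simpler iterate signature; the index has to live somewhere, so it moves into the carried state.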