jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add control flow operators for iterated functions #9960

Open carlosgmartin opened 2 years ago

carlosgmartin commented 2 years ago

Add the following control flow operators for iterated functions:

# iterate : ∀a. ℕ → (a → a) → (a → a)
def iterate(n, f, x):
    for _ in range(n):
        x = f(x)
    return x

# orbit : ∀a. ℕ → (a → a) → (a → [a])
def orbit(n, f, x):
    xs = [x]
    for _ in range(n):
        x = f(x)
        xs.append(x)
    return xs

iterate and fori_loop are mutually definable, but the former has a simpler signature, its semantics are more common in my experience, and it is arguably more "basic" or "fundamental".
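Here is a minimal sketch of that mutual definability. It assumes the reference `iterate` above, repeated so the block is self-contained, and uses `jax.lax.fori_loop` for the other direction. The names `iterate_via_fori` and `fori_via_iterate` are hypothetical. `fori_loop` passes the loop index to its body, so `iterate` simply ignores it. Going the other way, the index is threaded through the state as a pair.

```python
from jax import lax

def iterate(n, f, x):
    # Reference definition from the issue: apply f to x, n times.
    for _ in range(n):
        x = f(x)
    return x

def iterate_via_fori(n, f, x):
    # Hypothetical helper: fori_loop's body receives (i, x); iterate ignores i.
    return lax.fori_loop(0, n, lambda _, x: f(x), x)

def fori_via_iterate(lower, upper, body, x):
    # Hypothetical helper: carry the loop index alongside the value,
    # so an index-free step function is enough.
    def step(state):
        i, x = state
        return i + 1, body(i, x)
    return iterate(upper - lower, step, (lower, x))[1]

f = lambda x: 2 * x + 1
g = lambda i, x: 3 + 2 * x + 4 * i + i * x
print(iterate_via_fori(10, f, 3))     # 4095
print(fori_via_iterate(2, 8, g, 10))  # 827986
```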

jakevdp commented 2 years ago

Interesting idea - if I understand your suggestion correctly, I think the implementations would look like this:

from jax import lax

def iterate(n, f, x):
  return lax.scan(lambda x, _: (f(x), x), x, None, length=n)[0]

def orbit(n, f, x):
  return lax.scan(lambda x, _: (f(x), x), x, None, length=n + 1)[1]

Is that what you have in mind?

mattjj commented 2 years ago

For iterate we may want lax.scan(lambda x, _: (f(x), None), x, None, length=n)[0], i.e. don't have an extensive output (a per-iteration value that scan stacks into an array).
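The following small sketch makes the difference concrete. It borrows `f` and the starting value from the example later in this thread. When the scan body returns `None` as its per-step output, `lax.scan` returns `None` in place of the stacked outputs, so the program never asks for a length-`n` array of intermediates.

```python
from jax import lax

f = lambda x: 2 * x + 1

# Body emits the carry at each step: scan stacks n intermediate values.
final_a, ys_a = lax.scan(lambda x, _: (f(x), x), 3, None, length=10)

# Body emits None: there is no extensive output to stack.
final_b, ys_b = lax.scan(lambda x, _: (f(x), None), 3, None, length=10)

print(final_a, ys_a.shape)  # 4095 (10,)
print(final_b, ys_b)        # 4095 None
```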

carlosgmartin commented 2 years ago

@jakevdp @mattjj Looks right:

from jax.numpy import array
from jax.lax import scan

def iterate_1(n, f, x):
    for _ in range(n):
        x = f(x)
    return x

def iterate_2(n, f, x):
    return scan(lambda x, _: (f(x), x), x, None, length=n)[0]

def iterate_3(n, f, x):
    return scan(lambda x, _: (f(x), None), x, None, length=n)[0]

def orbit_1(n, f, x):
    xs = [x]
    for _ in range(n):
        x = f(x)
        xs.append(x)
    return array(xs)

def orbit_2(n, f, x):
    return scan(lambda x, _: (f(x), x), x, None, length=n + 1)[1]

n = 10
f = lambda x: 2 * x + 1
x = 3
print(iterate_1(n, f, x)) # 4095
print(iterate_2(n, f, x)) # 4095
print(iterate_3(n, f, x)) # 4095
print(orbit_1(n, f, x)) # [   3    7   15   31   63  127  255  511 1023 2047 4095]
print(orbit_2(n, f, x)) # [   3    7   15   31   63  127  255  511 1023 2047 4095]
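One possible refinement, sketched here as an assumption rather than anything proposed above: `orbit_2` runs `n + 1` steps, so `f` is evaluated once more than needed and that last result is discarded. Emitting each new value and prepending the starting point applies `f` exactly `n` times. The name `orbit_3` is hypothetical, and the sketch assumes `x` is a single array rather than a general pytree.

```python
import jax.numpy as jnp
from jax import lax

def orbit_3(n, f, x):
    # Hypothetical variant: emit f(x) at each of n steps, then prepend x.
    x = jnp.asarray(x)
    def step(x, _):
        y = f(x)
        return y, y
    _, ys = lax.scan(step, x, None, length=n)
    return jnp.concatenate((x[None], ys))

print(orbit_3(10, lambda x: 2 * x + 1, 3))
# [   3    7   15   31   63  127  255  511 1023 2047 4095]
```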

carlosgmartin commented 2 years ago

Out of curiosity, can this function be implemented in terms of lax primitives?

from jax.numpy import array

# orbit_while : ∀a. (a → 𝔹) → (a → a) → (a → [a])
def orbit_while(p, f, x):
    xs = []
    while p(x):
        xs.append(x)
        x = f(x)
    return array(xs) # contains only elements satisfying p

p = lambda x: x < 2047
f = lambda x: 2 * x + 1
x = 3

print(orbit_while(p, f, x)) # [   3    7   15   31   63  127  255  511 1023]

jakevdp commented 2 years ago

A further question for discussion: are these new APIs necessary if they can be implemented via a single call to scan? One might argue that iterate and orbit already exist in JAX; they're just called scan. What do you think?

carlosgmartin commented 2 years ago

@jakevdp I think it's convenient to have these useful helper functions in the library to make it easier for users to adopt JAX, and to save them the trouble of figuring out how to implement them in terms of scan. After all, I don't think it'd be good to take away all existing control flow functions that can be implemented in terms of scan.

jakevdp commented 2 years ago

Sure, but adding new APIs does not come without maintenance costs. It's true that while_loop and fori_loop can be implemented in terms of scan, but their implementations are far more involved than a single line.

As this is the first time I'm aware of receiving such a request, I would lean toward not including them in the package API, but I'm happy to change my mind if there are compelling reasons to do so.

soraros commented 2 years ago

I'd much rather have a bounded while loop.

carlosgmartin commented 2 years ago

@jakevdp I think this works as a single-line implementation of fori_loop:

from jax.lax import fori_loop, scan

def fori_loop_scan(a, b, f, x):
  return scan(lambda ix, _: ((ix[0] + 1, f(*ix)), None), (a, x), None, b - a)[0][1]

a = 2
b = 8
f = lambda i, x: 3 + 2 * x + 4 * i + i * x
x = 10

y1 = fori_loop(a, b, f, x)
y2 = fori_loop_scan(a, b, f, x)
print(y1) # 827986
print(y2) # 827986

Out of curiosity, what's the implementation of while_loop in terms of scan?

Also, is it possible to implement orbit_while in terms of lax primitives?

@soraros What do you mean?

jakevdp commented 2 years ago

while_loop doesn't lower to scan, actually, since scan requires a static number of iterations.
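Here is a small sketch of what "static" means in practice. The function names are hypothetical. A scan-based `iterate` can be jitted only if `n` is a compile-time constant, marked via `static_argnums`. Passing a traced `n` as the scan length fails at trace time. `lax.fori_loop`, by contrast, accepts traced bounds, and in that case it is lowered to a `while_loop`.

```python
from functools import partial

import jax
from jax import lax

f = lambda x: 2 * x + 1

@partial(jax.jit, static_argnums=(0, 1))
def iterate_static(n, f, x):
    # n fixes the scan length, so it must be known at compile time.
    return lax.scan(lambda x, _: (f(x), None), x, None, length=n)[0]

@partial(jax.jit, static_argnums=(1,))
def iterate_dynamic(n, f, x):
    # n is traced here; fori_loop then lowers to while_loop.
    return lax.fori_loop(0, n, lambda _, x: f(x), x)

@jax.jit
def iterate_bad(n, x):
    # n is traced, but scan needs a concrete length: tracing raises.
    return lax.scan(lambda x, _: (f(x), None), x, None, length=n)[0]

print(iterate_static(10, f, 3))   # 4095
print(iterate_dynamic(10, f, 3))  # 4095
try:
    iterate_bad(10, 3)
    scan_failed = False
except Exception:  # JAX rejects the non-concrete length
    scan_failed = True
print(scan_failed)  # True
```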

jakevdp commented 2 years ago

And orbit_while is not currently possible to express in JIT-compatible JAX, because it returns an array of dynamic length.
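A common workaround is sketched here. It assumes an upper bound on the orbit length is known in advance, fixes the output size at `max_len`, and returns a validity count alongside the padded buffer. The helper name `orbit_while_bounded` and its `(buffer, count)` return convention are hypothetical, not an existing JAX API. Entries at index `count` and beyond are padding. Under `jax.jit`, `p`, `f` and `max_len` would need to be static arguments.

```python
import jax.numpy as jnp
from jax import lax

def orbit_while_bounded(p, f, x, max_len):
    # Hypothetical fixed-size variant of orbit_while.
    # Returns (buf, count): buf has max_len entries, and only buf[:count]
    # are valid orbit elements; the rest are padding.
    def step(carry, _):
        x, active = carry
        keep = active & p(x)               # once p fails, stay inactive
        x_next = lax.cond(keep, f, lambda x: x, x)
        return (x_next, keep), (x, keep)   # emit x with its validity flag
    init = (jnp.asarray(x), jnp.array(True))
    _, (buf, mask) = lax.scan(step, init, None, length=max_len)
    return buf, jnp.sum(mask)

p = lambda x: x < 2047
f = lambda x: 2 * x + 1
buf, count = orbit_while_bounded(p, f, 3, max_len=16)
print(buf[:int(count)])  # [   3    7   15   31   63  127  255  511 1023]
```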

carlosgmartin commented 2 years ago

@jakevdp That's what I suspected. Thanks.

carlosgmartin commented 2 years ago

@soraros By a bounded while loop, do you mean something like this?

import jax

def bounded_while_loop(p, f, x, n):
    def g(i, x):
        return jax.lax.cond(p(x), f, lambda x: x, x)
    return jax.lax.fori_loop(0, n, g, x)

def p(x):
    return x < 10

def f(x):
    return x + 1

x = 0
print(jax.lax.while_loop(p, f, x)) # 10
print(bounded_while_loop(p, f, x, 100)) # 10
print(bounded_while_loop(p, f, x, 5)) # 5
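For comparison, here is a sketch of a variant that stops early. The name is hypothetical. `bounded_while_loop` above always runs `n` iterations and uses `cond` to turn the remaining ones into no-ops. The version below instead carries a counter and exits as soon as `p` fails or `n` steps have run. The trade-off, as far as I know, is differentiation: `lax.while_loop` does not support reverse-mode autodiff, whereas a `fori_loop` with static bounds does.

```python
import jax.numpy as jnp
from jax import lax

def bounded_while_loop_early_exit(p, f, x, n):
    # Hypothetical variant: loop while p(x) holds and fewer than n steps have run.
    def cond_fun(state):
        i, x = state
        return (i < n) & p(x)
    def body_fun(state):
        i, x = state
        return i + 1, f(x)
    _, x = lax.while_loop(cond_fun, body_fun, (jnp.int32(0), x))
    return x

p = lambda x: x < 10
f = lambda x: x + 1
print(bounded_while_loop_early_exit(p, f, 0, 100))  # 10
print(bounded_while_loop_early_exit(p, f, 0, 5))    # 5
```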

carlosgmartin commented 1 year ago

@jakevdp What do you think of letting xs=None by default in scan? That pattern seems to occur often.

jakevdp commented 1 year ago

I think that could be an improvement, but I'd want to hear opinions from other folks on the team.

carlosgmartin commented 1 year ago

I'd also like to suggest the following functions:

from typing import Callable, Optional

import jax.numpy as jnp
from jax import lax
from jax.tree_util import tree_map

def foldl(f: Callable, h, xs, length: Optional[int] = None):
    '''
    http://zvon.org/other/haskell/Outputprelude/foldl_f.html
    (a → b → a) → a → [b] → a
    Arguments:
    `f`: Function.
    `h`: Initial value.
    `xs`: Inputs.
    `length`: Number of steps; required when `xs` is None.
    Returns:
    Final value.
    '''
    def g(h, x):
        return f(h, x), None
    h, _ = lax.scan(g, h, xs, length)
    return h

def scanl(f: Callable, h, xs, length: Optional[int] = None):
    '''
    http://zvon.org/other/haskell/Outputprelude/scanl_f.html
    (a → b → a) → a → [b] → [a]
    Arguments:
    `f`: Function.
    `h`: Initial value.
    `xs`: Inputs.
    `length`: Number of steps; required when `xs` is None.
    Returns:
    Intermediate values, including the initial and final values.
    '''
    def g(h, x):
        # Emit the carry *before* this step; the final carry is appended below.
        return f(h, x), h
    h, hs = lax.scan(g, h, xs, length)
    return tree_map(lambda h, hs: jnp.concatenate((hs, h[None])), h, hs)

These are very general and useful abstractions for processing sequences.
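For readers coming from Python rather than Haskell, `foldl` corresponds to `functools.reduce` with an initial value, and `scanl` corresponds to `itertools.accumulate` with `initial=`. The sketch below inlines the two scan bodies from the definitions above and checks them against those standard-library functions on a small input. The step function and data are arbitrary, chosen only so that the operation is not commutative.

```python
import functools
import itertools

import jax.numpy as jnp
from jax import lax

f = lambda h, x: 2 * h + x  # arbitrary non-commutative step, for illustration
h0 = 1
init = jnp.asarray(h0, dtype=jnp.int32)
xs = jnp.array([3, 1, 4, 1, 5], dtype=jnp.int32)

# foldl: keep only the final carry.
fold_scan, _ = lax.scan(lambda h, x: (f(h, x), None), init, xs)

# scanl: emit the carry before each step, then append the final carry.
final, hs = lax.scan(lambda h, x: (f(h, x), h), init, xs)
scan_scan = jnp.concatenate((hs, final[None]))

fold_ref = functools.reduce(f, xs.tolist(), h0)
scan_ref = list(itertools.accumulate(xs.tolist(), f, initial=h0))

print(int(fold_scan), fold_ref)     # 111 111
print(scan_scan.tolist(), scan_ref) # [1, 5, 11, 26, 53, 111] [1, 5, 11, 26, 53, 111]
```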

I think it's a good idea to add common control-flow constructs like these to the standard library. This saves users the trouble of having to figure out how to implement them in terms of scan, which has a more complex interface. (The latter can be especially inconvenient for new users not accustomed to the purely-functional approach.)