google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`lax.scan` ~100x slower than recursion? #2251

Open rtbs-dev opened 4 years ago

rtbs-dev commented 4 years ago

Working on inference on network infection cascades, eventually with numpyro. At the moment, I'm trying to make a fast generative model, taking some cues from here on fast sequential loops. Note the use of lax.scan.

However, surprisingly, scan seems to run significantly slower than recursion(!).

Here's an example setup; I've tried to comment it as best I could, quickly. I'm running in a notebook.

import numpy as onp
import jax.numpy as np
from jax.random import PRNGKey
# from jax.config import config
from jax import jit, grad, lax, random, vmap
from jax.ops import index_update, index, index_add

n_nodes = 50
n_edges = n_nodes*(n_nodes-1)//2

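# per-edge transmission times: geometric draws whose success probabilities are Beta(2, 5)-distributed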
trans_times = onp.random.geometric(
    onp.random.beta(2,5,size=(n_edges,)),
    size=(n_edges,)
)

@jit 
def jax_squareform(edgelist, n=n_nodes):
    """edgelist to adj. matrix"""
    empty = np.zeros((n,n))
    half = index_add(empty, index[np.triu_indices(n,1)], edgelist)
    full = half+half.T
    return full

a = jax_squareform(trans_times)  # transition times
x0 = np.array([1]+(n_nodes-1)*[0])  # infect state

from collections import namedtuple
# (infected?, time-left-per-neighbor?)
InfectState = namedtuple('InfectState', ['x', 's_ij'])

@jit
def infect(state, step=1):
    neighbor_set = state.s_ij*state.x  # who knows an infected node?
    getting_infected=np.any(neighbor_set==1, axis=1) # and is getting infected now?
    x_p = lax.clamp(0,state.x+getting_infected, 1) # update infections
    s_ij_p = lax.max(state.s_ij - step*getting_infected, 0.) # and time-left
    return InfectState(x=x_p, s_ij=s_ij_p), step  # new state

So at this point, the individual time-steps are running really fast:

>>> %timeit infect(InfectState(x0, a))
211 µs ± 3.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Now let's make the loops:

def pandemic(state, step=1, t=0, T=5):
    state_p, _ = infect(state, step=step)

    if t==T:
        return state_p
    elif (t>=0) and (t<T):
        return pandemic(state_p, t=t+step)
    else:
        print('INVALID t!')

def pandemic_scan(state, step=1, t=0, T=5):
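    # NB: the incoming `state` is ignored here; the scan always starts from the globals x0 and a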
    return lax.scan(
        infect, 
        InfectState(x0, a), 
        np.full(T,step)
    )

So the difference is pretty stark:

>>> %timeit pandemic(InfectState(x0, a), T=5)
1.42 ms ± 4.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

>>> %timeit pandemic_scan(InfectState(x0, a), T=5)
164 ms ± 638 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

So here are the key questions:

  1. Is there a more idiomatic way to use scan in this case that avoids whatever slow-down is occurring? The ultimate use-case involves inference around a bunch of x(T) observations, estimating x0 for each x(T) and an overall s_ij given all of them, so presumably this needs to be fast.

  2. Is there a version of pandemic within the jax ecosystem that might allow jit-compilation? It seems that the dependence on the boolean comparison t<T causes it to complain about static_argnums, etc. (see the sketch after this list).
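
For what it's worth, here is a minimal sketch of one jit-friendly shape for pandemic, using lax.fori_loop (hypothetical; it reuses the infect defined above, and I haven't benchmarked it):

def pandemic_fori(state, step=1, T=5):
    # carry the InfectState pytree through T applications of infect;
    # with a concrete T, fori_loop lowers to a scan and stays differentiable
    def body(i, st):
        st_p, _ = infect(st, step=step)
        return st_p
    return lax.fori_loop(0, T, body, state)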

shoyer commented 4 years ago

Can you try adding a jit decorator around pandemic_scan? I think every invocation may be triggering a new compilation.
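
That is, roughly (a sketch of the suggestion, reusing the definitions above):

@jit
def pandemic_scan(state, step=1, t=0, T=5):
    return lax.scan(
        infect,
        InfectState(x0, a),
        np.full(T, step)
    )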

rtbs-dev commented 4 years ago

@shoyer as-is, putting an @jit decorator on it raises a TypeError:

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

rtbs-dev commented 4 years ago

Also, I'm not sure if this is relevant to a diagnosis, but refactoring the infect function with sub-compilation seems to have made the gap even more stark, now at nearly 300x:

@jit
def countdown(s_ij, x, step):
    t_minus = s_ij - step*x
    sym = lax.min(t_minus, t_minus.T)
    return lax.max(sym, 0.)  # no neg. times

@jit
def infect(state, step=1):
    neighbor_set = state.s_ij*state.x  # who knows an infected node?
    getting_infected=np.any(neighbor_set==1, axis=1) # and is getting infected now?
    x_p = lax.clamp(0,state.x+getting_infected, 1) # update infections
    s_ij_p = countdown(state.s_ij, getting_infected, step)  # and time-left

    return InfectState(x=x_p, s_ij=s_ij_p), step  # new state

>>> %timeit pandemic_scan(InfectState(x0, a), T=5)
291 ms ± 13.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit pandemic(InfectState(x0, a), T=5)
926 µs ± 29.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shoyer commented 4 years ago

Try making step, t, and T static arguments, so that np.full(T, step) sees a concrete integer at trace time:

import functools
import jax

@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def _pandemic_scan(state, step, t, T):
    return lax.scan(
        infect, 
        InfectState(x0, a), 
        np.full(T,step)
    )

def pandemic_scan(state, step=1, t=0, T=5):
    return _pandemic_scan(state, step, t, T)

rtbs-dev commented 4 years ago

@shoyer Ok, so yeah that appears to have done it:

import functools

@functools.partial(jit, static_argnums=(1,2,3))
def _pandemic_scan(state, step, t, T):
    return lax.scan(
        infect, 
        InfectState(x0, a), 
        np.full(T,step)
    )

def pandemic_scan(state, step=1, t=0, T=5):
    return _pandemic_scan(state, step, t, T)

>>> %timeit pandemic_scan(InfectState(x0, a), T=5)
141 µs ± 19.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

So I'm going to have to ask: why is jit-compiling it this way so much faster? What exactly is going on here, and what prevents a trick like this from being integrated into the default behaviour of jit?

For the heck of it I tried a "local" version of this, where the "private" version is only defined in the outer function's scope:

def pandemic_scan(state, step=1, t=0, T=5):

    @functools.partial(jit, static_argnums=(1,2,3))
    def _pandemic_scan(state, step, t, T):
        return lax.scan(
            infect, 
            InfectState(x0, a), 
            np.full(T,step)
        )

    return _pandemic_scan(state, step, t, T)

and the result:

272 ms ± 4.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

So...what exactly is going on here? Is this documented behaviour?

Thanks again for your help!

shoyer commented 4 years ago

You can get a good sense of the problem if you run the computation through a Python profiler, like %prun in IPython. You'll see that your code is getting compiled each time it's run, instead of reusing the same compiled code.
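
For example:

%prun pandemic_scan(InfectState(x0, a), T=5)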

The immediate source of the problem here is that lax.scan effectively always calls jit on its function argument, but no reference to that jitted function is saved between calls. It's the same issue in your "local" version: each jit is effectively run from scratch, which means caching fails.
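
As a minimal illustration (schematic, not your code), jit's cache is keyed on the function object it wraps, so a jitted function that is re-created on every call can never hit the cache:

import jax

@jax.jit
def _step(x):
    return x * 2  # traced and compiled once; cached on the _step object

def fast(x):
    return _step(x)  # same function object every call -> cache hit

def slow(x):
    @jax.jit
    def _step_local(x):
        return x * 2
    return _step_local(x)  # brand-new function object each call -> recompile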

This is definitely known behavior (and it's likely unavoidable), but it clearly isn't well documented. We can and should fix that! :)