Status: Open. rtbs-dev opened this issue 4 years ago.
Can you try adding a jit decorator around pandemic_scan? I think every invocation may be triggering a new compilation.
@shoyer As-is, putting an @jit decorator on it raises a TypeError:
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
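The error above comes from building an array whose length is an argument: under plain `jit` that argument becomes a tracer, but `jnp.full` needs a concrete integer length. A minimal sketch of the `static_argnums` fix the message suggests; `make_steps` is a hypothetical stand-in for the `np.full(T, step)` call inside `pandemic_scan`.

```python
import functools

import jax
import jax.numpy as jnp


def make_steps(step, T):
    # jnp.full needs a concrete integer length; under plain jit, T is a tracer.
    return jnp.full(T, step)


# Marking T (argument 1) as static makes it an ordinary Python int at trace time.
make_steps_static = functools.partial(jax.jit, static_argnums=(1,))(make_steps)

try:
    jax.jit(make_steps)(1, 5)  # T is traced here, so the shape is not concrete
except TypeError as err:
    print("plain jit failed:", type(err).__name__)

print(make_steps_static(1, 5))  # [1 1 1 1 1]
```

Each distinct value of a static argument produces its own compiled version, so this works best for arguments like `T` that take only a few values.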
Also, I'm not sure if this is relevant to a diagnosis, but refactoring the infect function with sub-compilation (jitting its helpers) seems to have made the difference even more stark: the scan version is now roughly 300x slower.
@jit
def countdown(s_ij, x, step):
    t_minus = s_ij - step*x
    sym = lax.min(t_minus, t_minus.T)
    return lax.max(sym, 0.)  # no neg. times

@jit
def infect(state, step=1):
    neighbor_set = state.s_ij*state.x  # who knows an infected node?
    getting_infected = np.any(neighbor_set==1, axis=1)  # and is getting infected now?
    x_p = lax.clamp(0, state.x+getting_infected, 1)  # update infections
    s_ij_p = countdown(state.s_ij, getting_infected, step)  # and time-left
    return InfectState(x=x_p, s_ij=s_ij_p), step  # new state
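Since the setup cell defining `InfectState`, `x0` and `a` is not shown, here is a minimal self-contained harness for the step above, run on a toy 3-node graph. It assumes `InfectState` is a `NamedTuple` with fields `x` and `s_ij` (a hypothetical definition), and it swaps the `lax.min`/`lax.max`/`lax.clamp` calls for the `jnp.minimum`/`jnp.maximum`/`jnp.clip` equivalents to sidestep dtype-matching rules on `lax` ops.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class InfectState(NamedTuple):
    # Hypothetical container; the issue's own definition is not shown.
    x: jax.Array     # (n,) infection indicator per node, 0. or 1.
    s_ij: jax.Array  # (n, n) symmetric "time left" per edge


@jax.jit
def countdown(s_ij, x, step):
    t_minus = s_ij - step * x              # subtract `step` in the columns of newly infected nodes
    sym = jnp.minimum(t_minus, t_minus.T)  # keep the matrix symmetric
    return jnp.maximum(sym, 0.0)           # no negative times


@jax.jit
def infect(state, step=1):
    neighbor_set = state.s_ij * state.x                     # edge time-left times neighbor's infection
    getting_infected = jnp.any(neighbor_set == 1, axis=1)   # an infected neighbor with 1 step left?
    x_p = jnp.clip(state.x + getting_infected, 0, 1)        # update infections
    s_ij_p = countdown(state.s_ij, getting_infected, step)  # and time left
    return InfectState(x=x_p, s_ij=s_ij_p), step


# Toy path graph 0-1-2: edge 0-1 has 1 step left, edge 1-2 has 2.
s0 = jnp.array([[0.0, 1.0, 0.0],
                [1.0, 0.0, 2.0],
                [0.0, 2.0, 0.0]])
state0 = InfectState(x=jnp.array([1.0, 0.0, 0.0]), s_ij=s0)
state1, _ = infect(state0, 1)  # node 1 infected; edge 1-2 counts down from 2 to 1
state2, _ = infect(state1, 1)  # node 2 infected
print(state1.x, state2.x)
```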
>>> %timeit pandemic_scan(InfectState(x0, a), T=5)
291 ms ± 13.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit pandemic(InfectState(x0, a), T=5)
926 µs ± 29.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
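JAX dispatches work asynchronously, so per-call timings are easier to interpret when every result is blocked on and the first (tracing and compiling) call is excluded. A minimal sketch of such a helper; `time_call` and `double` are hypothetical names, and `double` only stands in for something like `pandemic_scan`.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args, n_runs=20):
    """Mean seconds per call of fn(*args), excluding one warm-up call.

    Each result is blocked on before moving on, so the measurement covers
    the computation and not just the asynchronous dispatch.
    """
    jax.block_until_ready(fn(*args))  # warm-up: the first call traces and compiles
    start = time.perf_counter()
    for _ in range(n_runs):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_runs


double = jax.jit(lambda v: 2.0 * v)
print(f"{time_call(double, jnp.ones(1000)) * 1e6:.1f} µs per call")
```

If a function recompiles on every call, as described for pandemic_scan in this thread, its per-call time stays high even after the warm-up, which separates it from a one-off compilation cost.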
Try:
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def _pandemic_scan(state, step, t, T):
    return lax.scan(
        infect,
        state,  # scan from the state that was passed in
        np.full(T, step)
    )

def pandemic_scan(state, step=1, t=0, T=5):
    return _pandemic_scan(state, step, t, T)
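The pattern suggested above, reduced to a self-contained sketch: one jitted function defined once at module level (with the length-determining arguments static), plus a thin wrapper that only supplies defaults. `body` is a hypothetical running-sum stand-in for `infect`, and the unused `t` argument is dropped here.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


def body(carry, step):
    # Hypothetical stand-in for `infect`: (carry, x) -> (new_carry, output).
    new_carry = carry + step
    return new_carry, new_carry


# Defined once at module level, so repeated calls reuse the same jit cache.
@functools.partial(jax.jit, static_argnums=(1, 2))
def _run_scan(init, step, T):
    # `step` and `T` are static because jnp.full needs a concrete length.
    return lax.scan(body, init, jnp.full(T, step))


def run_scan(init, step=1, T=5):
    """Thin wrapper that only fills in default arguments."""
    return _run_scan(init, step, T)


final, history = run_scan(jnp.float32(0.0))
print(final, history)
```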
@shoyer Ok, so yeah that appears to have done it:
import functools

@functools.partial(jit, static_argnums=(1,2,3))
def _pandemic_scan(state, step, t, T):
    return lax.scan(
        infect,
        InfectState(x0, a),
        np.full(T,step)
    )

def pandemic_scan(state, step=1, t=0, T=5):
    return _pandemic_scan(state, step, t, T)
>>> %timeit pandemic_scan(InfectState(x0, a), T=5)
141 µs ± 19.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
So I have to ask: why is jit-compiling a separate function that the wrapper calls faster? What exactly is going on here, and what prevents a trick like this from being integrated into the default behaviour of jit?
For the heck of it I tried a "local" version of this, where the "private" version is only defined in the outer function's scope:
def pandemic_scan(state, step=1, t=0, T=5):
    @functools.partial(jit, static_argnums=(1,2,3))
    def _pandemic_scan(state, step, t, T):
        return lax.scan(
            infect,
            InfectState(x0, a),
            np.full(T,step)
        )
    return _pandemic_scan(state, step, t, T)
and the result:
272 ms ± 4.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
So...what exactly is going on here? Is this documented behaviour?
Thanks again for your help!
You can get a good sense of the problem if you run the computation through the Python profiler, e.g. %prun in IPython. You'll see that your code is getting compiled each time it's run, instead of reusing the same compiled code.
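Outside IPython, roughly the same profile can be collected with the standard-library `cProfile` and `pstats` modules (which is what `%prun` builds on). A minimal sketch; `profile_calls` and `toy_workload` are hypothetical names. Sorting by cumulative time and checking whether compilation-related entries show up once per call, rather than once in total, is the kind of signal described above; the exact JAX function names in the report vary between versions.

```python
import cProfile
import io
import pstats


def profile_calls(fn, *args, n_calls=3, top=20, **kwargs):
    """Run fn(*args, **kwargs) n_calls times under cProfile; return the report text."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        for _ in range(n_calls):
            fn(*args, **kwargs)
    finally:
        profiler.disable()
    buf = io.StringIO()
    stats = pstats.Stats(profiler, stream=buf)
    stats.sort_stats("cumulative").print_stats(top)  # most expensive call chains first
    return buf.getvalue()


def toy_workload(n):
    # Hypothetical stand-in for e.g. pandemic_scan(InfectState(x0, a), T=5).
    return sum(i * i for i in range(n))


report = profile_calls(toy_workload, 1000)
print(report)
```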
The immediate source of the problem here is that lax.scan effectively always calls jit on its function argument, but no reference to that jitted function is saved. It's the same issue in your "local" version: each jit is effectively being run from scratch, which means caching fails.
This is definitely known behavior (and it's likely unavoidable), but it clearly isn't well documented. We can and should fix that! :)
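The caching failure described above can be made visible by counting traces with a Python side effect, which runs only while JAX traces a function. A minimal sketch, with a hypothetical running-sum `body` standing in for `infect`: the module-level jitted function is traced once and then reused, while the version defined inside the wrapper is a brand-new function object on every call and is traced every time. Repeated tracing is the Python-side overhead the comment refers to, and it can also lead to recompilation.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax

TRACE_COUNT = {"module": 0, "local": 0}


def body(carry, x):
    # Hypothetical scan body (stand-in for `infect`): running sum.
    carry = carry + x
    return carry, carry


@functools.partial(jax.jit, static_argnums=(1,))
def _scan_module(init, T):
    TRACE_COUNT["module"] += 1  # Python side effect: executes only during tracing
    return lax.scan(body, init, jnp.ones(T))


def scan_module(init, T=5):
    return _scan_module(init, T)


def scan_local(init, T=5):
    # A new function object (and jit wrapper) is created on every call,
    # so there is no earlier cache entry for it to hit.
    @functools.partial(jax.jit, static_argnums=(1,))
    def _scan_local(init, T):
        TRACE_COUNT["local"] += 1
        return lax.scan(body, init, jnp.ones(T))
    return _scan_local(init, T)


init = jnp.zeros((), dtype=jnp.float32)
for _ in range(3):
    scan_module(init)
    scan_local(init)

print(TRACE_COUNT)  # {'module': 1, 'local': 3}
```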
Working on inference on network infection cascades, eventually with numpyro. At the moment, I'm trying to make a fast generative model, taking some cues from here on fast sequential loops. Note the use of lax.scan. However, it seems that (surprisingly) scan runs significantly slower than recursion(!). Here's an example setup; I've tried to comment it as best I could, quickly. Running in a notebook.
So at this point, the individual time-steps are running real fast
Now let's make the loops:
So the difference is pretty stark:
So here are the key questions:
1. Is there a more idiomatic way to use scan in this case that avoids whatever slow-down is occurring? The ultimate use-case involves inference around a bunch of x(T) observations to estimate x0 for each x(T), and an overall s_ij given all of them. So presumably this needs to be fast.
2. Is there a version of pandemic within the jax ecosystem that might allow jit-compilation? It seems that the dependence on the boolean comparison t<T is causing it to complain about static_argnums, etc.
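On the second question, a minimal sketch of two ways to jit a recursive loop guarded by `t < T`, using a hypothetical `toy_step` in place of `infect` (the original `pandemic` code is not shown). Making `T` static lets the Python recursion run at trace time, unrolled into `T` steps, with one compilation per distinct `T`. Expressing the loop with `lax.fori_loop` lets `T` stay a traced value. Note that a `fori_loop` with non-concrete bounds is implemented with `lax.while_loop`, which does not support reverse-mode differentiation; with concrete bounds it is implemented with `scan`.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


def toy_step(state):
    # Hypothetical stand-in for one `infect` step.
    return 0.5 * state + 1.0


def run_recursive(state, t=0, T=5):
    # `t < T` is a Python bool, so t and T must be concrete Python values.
    if t < T:
        return run_recursive(toy_step(state), t + 1, T)
    return state


# Option 1: make T static; the recursion is unrolled at trace time.
@functools.partial(jax.jit, static_argnums=(1,))
def run_static(state, T):
    return run_recursive(state, 0, T)


# Option 2: use lax.fori_loop so T can remain a traced value.
@jax.jit
def run_fori(state, T):
    return lax.fori_loop(0, T, lambda i, s: toy_step(s), state)


x0 = jnp.zeros((), dtype=jnp.float32)
print(run_static(x0, 5), run_fori(x0, 5))  # both 1.9375
```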