larsgeb opened this issue 4 years ago

I have a simulation of a PDE on which I would like to perform Bayesian inference. I can (easily, thank you guys!) compute the 'forward' and gradient of the PDE, more efficiently than NumPy can, but the first-time use takes upwards of 30 minutes (think 500 degrees of freedom, time integration, etc.). It would be nice to store the JIT'ted and finally compiled methods somehow, to make subsequent use faster, but this doesn't seem to appear in the documentation. Is this possible?

If needed, I can provide a complete working example.
This is not documented because JAX does not have that feature.
Can you try to use jax.xla_computation to see if that is fast? If so, then the bulk of the time is taken by the XLA compiler. I do not understand how caching is done by XLA.
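For example, something like this (a minimal sketch; `f` is a stand-in for your actual program) separates the tracing time from the compile time:

    import time
    import jax
    import jax.numpy as jnp

    def f(x):
        return (jnp.sin(x) ** 2).sum()  # stand-in for the real computation

    x = jnp.ones((500, 1))

    t0 = time.time()
    computation = jax.xla_computation(f)(x)  # traces f into an XLA HLO computation, no compilation
    print("tracing:", time.time() - t0, "s")

    t0 = time.time()
    y = jax.jit(f)(x).block_until_ready()    # tracing + XLA compilation + execution
    print("jit compile + run:", time.time() - t0, "s")

If the first number is small and the second is large, the time is going into the XLA compiler.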
What changes between invocations in your use case? If it is only the input data, perhaps you can run the JAX program in an infinite loop reading inputs from somewhere else. If you actually change code, then how do you know that you don't need to recompile?
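That is, something like this (a minimal sketch; the random draws stand in for wherever your inputs actually come from):

    import jax
    import jax.numpy as jnp
    import numpy as onp

    @jax.jit
    def simulate(velocity_squared):
        # placeholder for the real PDE solver
        return (velocity_squared ** 2).sum()

    # stand-in for "reading inputs from somewhere else": here, random draws
    for _ in range(10):
        velocity_squared = jnp.asarray(onp.random.rand(500, 1))
        result = simulate(velocity_squared)  # compiled once, on the first iteration
        result.block_until_ready()           # later calls with this shape/dtype reuse the cache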
As a general proposition, if JAX builds a very large XLA computation, XLA may take a long time to compile it. That's probably what's happening here. It's hard to say why your computation is so large without a runnable reproduction. The most common source is Python loops, which often have the effect of unrolling the computation and should be replaced with lax loop constructs.
It would be possible to say more with a small runnable reproduction.
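To make the unrolling concrete, here is a minimal sketch: the Python loop version emits 1000 copies of the update into the traced computation, while the `lax` version stays a single loop primitive of constant size.

    import jax
    import jax.numpy as jnp

    def python_loop(x):
        for _ in range(1000):      # unrolled: 1000 copies of the op in the trace
            x = x * 0.5 + 1.0
        return x

    def lax_loop(x):
        return jax.lax.fori_loop(  # a single structured loop in the computation
            0, 1000, lambda i, x: x * 0.5 + 1.0, x
        )

    x = jnp.ones((500,))
    print(len(jax.xla_computation(python_loop)(x).as_hlo_text()))  # large HLO
    print(len(jax.xla_computation(lax_loop)(x).as_hlo_text()))     # small HLO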
Thank you both! It is indeed a Python for loop; a minimal code example is included below. Admittedly, it is maybe not implemented as well in JAX as it could be.

In every run, only the variable velocity_squared varies; the simulation code itself is constant. All the jit and xla_computation calls take on the order of tens of microseconds. The return object can also be rewritten as a single vector, or integrated with observations into a scalar output (the function at the very end illustrates this).
import jax
import jax.numpy as np
from jax import grad, jit
import numpy as onp
import scipy.signal as signal
def jax_wavesolver(velocity_squared):
    velocity_squared = np.asarray(velocity_squared)

    # Source time function: a Ricker wavelet in the first 100 samples.
    # If scipy is not available, this update can be removed; it should not
    # influence the computation.
    stf = np.zeros((2000,))
    stf = jax.ops.index_update(
        stf, jax.ops.index[:100], signal.ricker(100, 10)[:] / 10.0
    )

    p = np.zeros((500, 1))
    p_prev = np.zeros((500, 1))
    p_next = np.zeros((500, 1))
    dt = 0.25

    receiver_1 = np.empty_like(stf)
    receiver_2 = np.empty_like(stf)
    receiver_3 = np.empty_like(stf)

    for i in range(2000):
        # Second-order finite-difference update of the interior points.
        p_next = jax.ops.index_update(
            p_next,
            jax.ops.index[1:-1],
            2 * p[1:-1]
            - p_prev[1:-1]
            + dt * velocity_squared[1:-1] * (p[0:-2] - 2 * p[1:-1] + p[2:]),
        )
        # Inject the source at grid point 50.
        p_next = jax.ops.index_update(p_next, jax.ops.index[50], stf[i])
        # Record the wavefield at the three receiver locations (fixed: the
        # original updated receiver_1 in all three statements).
        receiver_1 = jax.ops.index_update(receiver_1, jax.ops.index[i], p_next[50, 0])
        receiver_2 = jax.ops.index_update(receiver_2, jax.ops.index[i], p_next[425, 0])
        receiver_3 = jax.ops.index_update(receiver_3, jax.ops.index[i], p_next[475, 0])
        p_prev = p
        p = p_next

    return receiver_1, receiver_2, receiver_3
velocity_squared = onp.ones((500, 1))
velocity_squared[250:, :] = 4.0
velocity_squared[300:, :] = 1.0
velocity_squared[350:, :] = 4.0
velocity_squared[400:, :] = 1.0
# IPython timer for performance, remove as needed.
%time J_ws = jit(jax_wavesolver)
%time jax.xla_computation(J_ws)
# This would give the 'observations', but currently extremely slow
observations = J_ws(velocity_squared)
def jax_wavesolver_misfit(velocity_squared, observations):
    velocity_squared = np.asarray(velocity_squared)

    stf = np.zeros((2000,))
    stf = jax.ops.index_update(
        stf, jax.ops.index[:100], signal.ricker(100, 10)[:] / 10.0
    )

    p = np.zeros((500, 1))
    p_prev = np.zeros((500, 1))
    p_next = np.zeros((500, 1))
    dt = 0.25
    misfit = 0.0

    receiver_1_obs = np.asarray(observations[0])
    receiver_2_obs = np.asarray(observations[1])
    receiver_3_obs = np.asarray(observations[2])

    for i in range(2000):
        p_next = jax.ops.index_update(
            p_next,
            jax.ops.index[1:-1],
            2 * p[1:-1]
            - p_prev[1:-1]
            + dt * velocity_squared[1:-1] * (p[0:-2] - 2 * p[1:-1] + p[2:]),
        )
        p_next = jax.ops.index_update(p_next, jax.ops.index[50], stf[i])
        # Accumulate the squared misfit at the three receivers; index the
        # second axis as well so the misfit stays a scalar (grad requires a
        # scalar output; the original p_next[50] had shape (1,)).
        misfit += (receiver_1_obs[i] - p_next[50, 0]) ** 2
        misfit += (receiver_2_obs[i] - p_next[425, 0]) ** 2
        misfit += (receiver_3_obs[i] - p_next[475, 0]) ** 2
        p_prev = p
        p = p_next

    return misfit

%time J_ws_misfit = jit(jax_wavesolver_misfit)
%time jax.xla_computation(J_ws_misfit)

# This would give the misfit, but is currently extremely slow
X = J_ws_misfit(velocity_squared, observations)

# This is what I am after ultimately:
gr_J_ws_misfit = jit(grad(jax_wavesolver_misfit))
G = gr_J_ws_misfit(velocity_squared, observations)
You need to use jax.lax.scan. Otherwise JAX will unroll the whole loop.
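Roughly, the time loop above could become something like this (an untested sketch, keeping the jax.ops.index_update API from your snippet; the carry holds the two previous wavefields and the stacked per-step outputs become the receiver traces):

    import jax
    import jax.numpy as np

    def jax_wavesolver_scan(velocity_squared, stf, dt=0.25):
        velocity_squared = np.asarray(velocity_squared)
        p0 = np.zeros((500, 1))

        def step(carry, stf_i):
            p_prev, p = carry
            p_next = np.zeros_like(p)
            p_next = jax.ops.index_update(
                p_next,
                jax.ops.index[1:-1],
                2 * p[1:-1]
                - p_prev[1:-1]
                + dt * velocity_squared[1:-1] * (p[0:-2] - 2 * p[1:-1] + p[2:]),
            )
            p_next = jax.ops.index_update(p_next, jax.ops.index[50], stf_i)
            # Per-step output: the wavefield sampled at the three receivers.
            recorded = np.stack([p_next[50, 0], p_next[425, 0], p_next[475, 0]])
            return (p, p_next), recorded

        # One structured loop over all 2000 source samples; the trace stays small.
        _, recordings = jax.lax.scan(step, (p0, p0), stf)
        return recordings[:, 0], recordings[:, 1], recordings[:, 2]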
I think @larsgeb 's original question was not that subsequent runs within a single execution of a Python file are slow, but that every jitted function has to be re-compiled when that file is closed and re-executed. That may become a performance bottleneck for larger problems, like his PDE simulator.

Actually, I had a similar problem in my application. Although subsequent executions of my code are extremely fast, the first run costs ~2 s. My question here is why the first run still needs to compile my jit-decorated function. It would be nice if that compilation could be done at the time of function declaration (I know that could be hard, since the types/shapes of the inputs are not known at that time), or if the compiled functions could be preserved and reused somehow.
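One workaround for paying the cost at a controlled moment, within a single process, is a warm-up call (a minimal sketch; compilation is keyed on input shape/dtype, so the dummy must match the real inputs):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        return (x ** 2).sum()  # placeholder for the real function

    dummy = jnp.zeros((500, 1))   # representative shape/dtype
    f(dummy).block_until_ready()  # forces tracing + XLA compilation now
    # later calls with (500, 1) float32 inputs reuse the cached executable

This does not persist the executable across processes, which I understand is the harder part of the request.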
Hi, have you solved your problem yet?

I'm also trying to solve a problem with lots of loops and frequent data updates. I find that using jax.lax.scan and jax.lax.fori_loop speeds up compilation by a factor of ~10-20 in my application. Hope that trick helps you too.
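For reference, here is roughly what the misfit loop could look like with fori_loop (an untested sketch; note that depending on the JAX version, reverse-mode grad may require the scan form instead, since fori_loop can lower to while_loop, which is not reverse-differentiable):

    import jax
    import jax.numpy as np

    def misfit_fori(velocity_squared, stf, observations, dt=0.25):
        velocity_squared = np.asarray(velocity_squared)
        obs = np.stack([np.asarray(o) for o in observations])  # shape (3, 2000)
        p0 = np.zeros((500, 1))

        def body(i, carry):
            p_prev, p, misfit = carry
            p_next = np.zeros_like(p)
            p_next = jax.ops.index_update(
                p_next,
                jax.ops.index[1:-1],
                2 * p[1:-1]
                - p_prev[1:-1]
                + dt * velocity_squared[1:-1] * (p[0:-2] - 2 * p[1:-1] + p[2:]),
            )
            p_next = jax.ops.index_update(p_next, jax.ops.index[50], stf[i])
            # Scalar accumulator carried through the loop.
            misfit = misfit + (obs[0, i] - p_next[50, 0]) ** 2
            misfit = misfit + (obs[1, i] - p_next[425, 0]) ** 2
            misfit = misfit + (obs[2, i] - p_next[475, 0]) ** 2
            return p, p_next, misfit

        _, _, misfit = jax.lax.fori_loop(0, 2000, body, (p0, p0, np.zeros(())))
        return misfit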