jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Is there a way to store (serialize) expensive first-time use JIT'ted methods? #2249

Open larsgeb opened 4 years ago

larsgeb commented 4 years ago

I have a simulation of a PDE on which I would like to perform Bayesian inference. I can (easily, thank you guys!) compute the 'forward' and gradient of the PDE more efficiently than NumPy can, but the first-time use takes upwards of 30 minutes (think 500 degrees of freedom, time integration, etc.). It would be nice to store the JIT'ted and compiled methods somehow to make subsequent use faster, but I can't find this in the documentation. Is this possible?

If needed, I can provide a complete working example.

gnecula commented 4 years ago

This is not documented because JAX does not have that feature.

Can you try using jax.xla_computation to see whether that is fast? If so, then the bulk of the time is taken by the XLA compiler. I do not understand how caching is done by XLA.
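A minimal sketch of the same check using the ahead-of-time API in current JAX (jax.xla_computation has been deprecated in recent releases): `jax.jit(f).lower(...)` traces the function and lowers it without invoking the XLA compiler, and `.compile()` then runs XLA. The function `loop_heavy` is a hypothetical stand-in whose Python loop is unrolled on purpose; the timings will vary by machine.

```python
import time

import jax
import jax.numpy as jnp


def loop_heavy(x):
    # Hypothetical stand-in: the Python loop is unrolled into 100 copies of the body.
    for _ in range(100):
        x = jnp.sin(x) + 0.1 * x
    return x


x = jnp.ones((500, 1))

t0 = time.perf_counter()
lowered = jax.jit(loop_heavy).lower(x)  # trace + lower only; XLA has not run yet
t1 = time.perf_counter()
compiled = lowered.compile()            # XLA compilation happens here
t2 = time.perf_counter()

print(f"trace+lower: {t1 - t0:.3f} s, XLA compile: {t2 - t1:.3f} s")
y = compiled(x)                          # run the compiled executable
```

If the second number dominates, the time is spent in the XLA compiler rather than in JAX's tracing.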

What changes between invocations in your use case? If it is only the input data, perhaps you can run the JAX program in an infinite loop reading inputs from somewhere else. If you actually change code, then how do you know that you don't need to recompile?
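The "keep one process alive" suggestion works because `jax.jit` caches the compiled program per argument shape and dtype, so new values with the same shape reuse it. A minimal sketch; `simulate` and `trace_count` are hypothetical names, and the Python counter only increments while JAX traces the function.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def simulate(velocity_squared):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing, not on cached calls
    return jnp.sum(velocity_squared ** 2)


for k in range(3):
    v = jnp.ones((500, 1)) * (1.0 + k)  # new values, same shape and dtype
    simulate(v)
after_same_shape = trace_count  # 1: traced and compiled once, then reused

simulate(jnp.ones((600, 1)))    # a new shape triggers a fresh trace and compile
print(after_same_shape, trace_count)  # 1 2
```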

hawkinsp commented 4 years ago

As a general proposition, if JAX builds a very large XLA computation, XLA may take a long time to compile it. That's probably what's happening here. It's hard to say why your computation is so large without a runnable reproduction. The most common source is Python loops, which often have the effect of unrolling the computation and should be replaced with lax loop constructs.
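One way to see the unrolling is to count the equations in the traced program (the jaxpr) as the loop length grows. A minimal sketch with hypothetical functions `unrolled` and `rolled`: the Python loop adds a copy of its body for every iteration, while `jax.lax.fori_loop` keeps a single copy regardless of the trip count.

```python
import jax
import jax.numpy as jnp


def unrolled(x, n):
    for _ in range(n):  # Python loop: the body is traced n times
        x = x * 0.5 + 1.0
    return x


def rolled(x, n):
    # lax.fori_loop traces the body once and loops inside the compiled program
    return jax.lax.fori_loop(0, n, lambda i, x: x * 0.5 + 1.0, x)


x = jnp.ones(500)
counts = {}
for n in (10, 100):
    n_unrolled = len(jax.make_jaxpr(lambda x: unrolled(x, n))(x).jaxpr.eqns)
    n_rolled = len(jax.make_jaxpr(lambda x: rolled(x, n))(x).jaxpr.eqns)
    counts[n] = (n_unrolled, n_rolled)
    print(n, n_unrolled, n_rolled)
```

The unrolled count grows with `n`; the `fori_loop` count stays the same. A larger program is what XLA then has to compile.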

It would be possible to say more with a small runnable reproduction.

larsgeb commented 4 years ago

Thank you both! It is indeed a Python for loop; minimal code is included below. Admittedly, it is maybe not implemented as well in JAX as it could be.

In every run, only the variable velocity_squared varies; the simulation code itself is constant. All the jit and xla_computation calls take on the order of tens of microseconds. The return value could also be rewritten as a single vector, or combined with observations into a scalar output (the function at the very end illustrates this).

import jax
import jax.numpy as np
from jax import grad, jit
import numpy as onp
import scipy.signal as signal

def jax_wavesolver(velocity_squared):

    velocity_squared = np.asarray(velocity_squared)

    stf = np.zeros((2000,))

    # If scipy is not available, this statement can be removed. The source is then zero, but the compilation cost should be unchanged.
    stf = jax.ops.index_update(
        stf, jax.ops.index[:100], signal.ricker(100, 10)[:] / 10.0
    )

    p = np.zeros((500, 1))
    p_prev = np.zeros((500, 1))
    p_next = np.zeros((500, 1))

    dt = 0.25

    misfit = 0.0

    receiver_1 = np.empty_like(stf)
    receiver_2 = np.empty_like(stf)
    receiver_3 = np.empty_like(stf)

    for i in range(2000):

        p_next = jax.ops.index_update(
            p_next,
            jax.ops.index[1:-1],
            2 * p[1:-1]
            - p_prev[1:-1]
            + dt * velocity_squared[1:-1] * (p[0:-2] - 2 * p[1:-1] + p[2:]),
        )

        p_next = jax.ops.index_update(p_next, jax.ops.index[50], stf[i])

        # Each receiver trace must be updated from its own previous value.
        receiver_1 = jax.ops.index_update(receiver_1, jax.ops.index[i], p_next[50, 0])
        receiver_2 = jax.ops.index_update(receiver_2, jax.ops.index[i], p_next[425, 0])
        receiver_3 = jax.ops.index_update(receiver_3, jax.ops.index[i], p_next[475, 0])

        p_prev = p
        p = p_next

    return receiver_1, receiver_2, receiver_3

velocity_squared = onp.ones((500, 1))
velocity_squared[250:, :] = 4.0
velocity_squared[300:, :] = 1.0
velocity_squared[350:, :] = 4.0
velocity_squared[400:, :] = 1.0

# IPython timer for performance, remove as needed.

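# jit() is lazy: it only wraps the function, so this takes microseconds.
%time J_ws = jit(jax_wavesolver)

# Without arguments this also only wraps the function; tracing happens in
# jax.xla_computation(jax_wavesolver)(velocity_squared).
%time jax.xla_computation(J_ws)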

# This would give the 'observations', but currently extremely slow
observations = J_ws(velocity_squared)

def jax_wavesolver_misfit(velocity_squared, observations):

    velocity_squared = np.asarray(velocity_squared)

    stf = np.zeros((2000,))
    stf = jax.ops.index_update(
        stf, jax.ops.index[:100], signal.ricker(100, 10)[:] / 10.0
    )

    p = np.zeros((500, 1))
    p_prev = np.zeros((500, 1))
    p_next = np.zeros((500, 1))

    dt = 0.25

    misfit = 0.0

    receiver_1_obs = np.asarray(observations[0])
    receiver_2_obs = np.asarray(observations[1])
    receiver_3_obs = np.asarray(observations[2])

    for i in range(2000):

        p_next = jax.ops.index_update(
            p_next,
            jax.ops.index[1:-1],
            2 * p[1:-1]
            - p_prev[1:-1]
            + dt * velocity_squared[1:-1] * (p[0:-2] - 2 * p[1:-1] + p[2:]),
        )

        p_next = jax.ops.index_update(p_next, jax.ops.index[50], stf[i])
        # Index both axes so misfit stays a scalar, as jax.grad requires.
        misfit += (receiver_1_obs[i] - p_next[50, 0]) ** 2
        misfit += (receiver_2_obs[i] - p_next[425, 0]) ** 2
        misfit += (receiver_3_obs[i] - p_next[475, 0]) ** 2

        p_prev = p
        p = p_next

    return misfit

%time J_ws_misfit = jit(jax_wavesolver_misfit)  # lazy: nothing is traced or compiled yet

%time jax.xla_computation(J_ws_misfit)  # also lazy without arguments

# This evaluates the misfit, but is currently extremely slow
X = J_ws_misfit(velocity_squared, observations)

# This is what I am after ultimately:
gr_J_ws_misfit = jit(grad(jax_wavesolver_misfit))
G = gr_J_ws_misfit(velocity_squared, observations)

yingted commented 4 years ago

You need to use jax.lax.scan. Otherwise JAX will unroll the whole loop.
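A minimal sketch of the wave solver rewritten with `jax.lax.scan`, so the time-step body is traced once instead of 2000 times. Assumptions: the three receiver traces are stacked into one `(2000, 3)` array instead of a tuple, the update formula (including `dt` rather than `dt**2`) is kept exactly as in the original, and `ricker` is a hypothetical helper intended to mirror `scipy.signal.ricker`, which recent SciPy releases removed.

```python
import jax
import jax.numpy as jnp

N_T = 2000                             # time steps, as in the original example
N_X = 500                              # grid points
DT = 0.25                              # same update coefficient as the original code
SRC = 50                               # source index
RECEIVERS = jnp.array([50, 425, 475])  # receiver indices


def ricker(points, a):
    # Hypothetical helper mirroring the formula of scipy.signal.ricker(points, a).
    amp = 2.0 / (jnp.sqrt(3.0 * a) * jnp.pi ** 0.25)
    t = jnp.arange(points) - (points - 1.0) / 2.0
    return amp * (1.0 - (t / a) ** 2) * jnp.exp(-(t ** 2) / (2.0 * a ** 2))


def step(velocity_squared, carry, stf_i):
    # One time step; scan traces this body once.
    p_prev, p = carry
    interior = (
        2 * p[1:-1]
        - p_prev[1:-1]
        + DT * velocity_squared[1:-1] * (p[:-2] - 2 * p[1:-1] + p[2:])
    )
    p_next = jnp.zeros_like(p).at[1:-1].set(interior)  # end points stay at zero
    p_next = p_next.at[SRC].set(stf_i)                  # inject the source
    return (p, p_next), p_next[RECEIVERS, 0]            # new carry, receiver samples


@jax.jit
def wavesolver(velocity_squared, stf):
    p0 = jnp.zeros((N_X, 1))
    _, traces = jax.lax.scan(
        lambda carry, s: step(velocity_squared, carry, s), (p0, p0), stf
    )
    return traces  # shape (N_T, 3): one column per receiver


def misfit(velocity_squared, stf, observations):
    # Scalar output, so jax.grad applies directly.
    return jnp.sum((wavesolver(velocity_squared, stf) - observations) ** 2)


grad_misfit = jax.jit(jax.grad(misfit))

stf = jnp.zeros(N_T).at[:100].set(ricker(100, 10.0) / 10.0)
velocity_squared = jnp.ones((N_X, 1))
for start, value in ((250, 4.0), (300, 1.0), (350, 4.0), (400, 1.0)):
    velocity_squared = velocity_squared.at[start:].set(value)

observations = wavesolver(velocity_squared, stf)
G = grad_misfit(0.99 * velocity_squared, stf, observations)
```

The carry holds the two previous wavefields, and `scan` stacks the per-step receiver samples into the output, which replaces the three `index_update` calls per iteration.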

YukunXia commented 4 years ago

> As a general proposition, if JAX builds a very large XLA computation, XLA may take a long time to compile it. That's probably what's happening here. It's hard to say why your computation is so large without a runnable reproduction. The most common source is Python loops, which often have the effect of unrolling the computation and should be replaced with lax loop constructs.
>
> It would be possible to say more with a small runnable reproduction.

I think @larsgeb's original question was not that subsequent calls within a single execution of a Python script are slow, but that every jitted function has to be recompiled each time the script is run again. That can become a performance bottleneck for larger problems, like his PDE simulator.
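Later JAX releases (well after this thread was opened) added a persistent compilation cache for exactly this cross-run case: compiled executables are written to a directory and, when the same program is compiled again in a later process with the same JAX version and settings, read back instead of recompiled. Tracing still happens each run; only the XLA compilation is skipped. A minimal sketch, assuming a recent JAX release; the cache directory is a hypothetical location, and backend support for the cache has varied between versions.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical cache location: any writable directory that survives between runs.
CACHE_DIR = os.path.join(tempfile.gettempdir(), "jax_compilation_cache")

# Set these before the first jit-compiled call in the process.
jax.config.update("jax_compilation_cache_dir", CACHE_DIR)
# Cache every executable, not only those that took a long time to compile.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


@jax.jit
def simulate(velocity_squared):
    # Hypothetical stand-in for the expensive simulation.
    return jnp.cumsum(velocity_squared, axis=0)


result = simulate(jnp.ones((500, 1)))  # compiled on the first run, loaded from disk afterwards
```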

Actually, I had a similar problem in my application. Although subsequent executions of my code are extremely fast, the first run costs ~2 s. My question is why the first run still needs to compile my jit-decorated function. It would be nice if that compilation could be done when the function is declared (I know that could be hard, since the types/shapes of the inputs are not known at that point), or if the compiled functions could be stored and reused somehow.
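When the input shapes are known in advance, current JAX can compile before any real data exists by passing `jax.ShapeDtypeStruct` placeholders to the ahead-of-time API. A minimal sketch; `misfit` here is a hypothetical toy forward model standing in for the PDE solver, and the shapes are assumptions matching the example above. The compiled object only accepts arguments with those shapes and dtypes, and the compilation still happens once per process: this moves the cost to a point of your choosing rather than removing it.

```python
import jax
import jax.numpy as jnp

RECEIVERS = jnp.array([50, 425, 475])  # hypothetical receiver positions


def misfit(velocity_squared, observations):
    # Hypothetical toy forward model standing in for the PDE solver.
    t = jnp.linspace(0.0, 1.0, observations.shape[0])
    predicted = jnp.outer(t, velocity_squared[RECEIVERS, 0])  # (n_t, 3)
    return jnp.sum((predicted - observations) ** 2)


# Only shapes and dtypes are needed here, not actual data.
v_spec = jax.ShapeDtypeStruct((500, 1), jnp.float32)
obs_spec = jax.ShapeDtypeStruct((2000, 3), jnp.float32)

# Trace, lower and compile up front (e.g. at import time) instead of on the first call.
grad_misfit = jax.jit(jax.grad(misfit)).lower(v_spec, obs_spec).compile()

g = grad_misfit(jnp.ones((500, 1)), jnp.zeros((2000, 3)))
```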

YukunXia commented 4 years ago

Hi, have you solved your problem yet?

I'm also working on a problem with many loops and frequent data updates. In my application, replacing Python loops with jax.lax.scan and jax.lax.fori_loop reduced compilation time by roughly 10-20x. I hope that trick helps you too.