jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Slow Python tracing time with odeint inside tfp.optimizer.lbfgs_minimize #3847

Open tfrerix opened 4 years ago

tfrerix commented 4 years ago

I want to optimize the initial condition of a dynamical system w.r.t. a criterion based on the system's trajectory. To this end, I use jax.experimental.ode.odeint inside of tfp.optimizer.lbfgs_minimize (cf. TFP-on-JAX L-BFGS). This might sound like a niche use case, but I expect it to come up more often when using JAX with physics models.

However, JIT compilation is very slow, which makes scaling up difficult. Is there a good reason for the slow compilation, or is there a way to speed it up?

Below are a code example and two graphs that show, as a function of the number of odeint integration steps, the compile time and the length of the code generated during JIT compilation, i.e., len(str(jax.make_jaxpr(...)(...))). The two curves have a similar shape and flatten out somewhat.

!pip install tfp-nightly==0.12.0.dev20200723

from functools import partial
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from tensorflow_probability.substrates import jax as tfp

def ode_func(x, t):
  """
  Time derivative of a linear ODE that decays the state.
  """
  return -0.1 * x

def loss(x, n_integration_steps):
  """
  Minimizing this function essentially just minimizes the norm of x.
  """
  t = jnp.asarray([n*0.1 for n in range(n_integration_steps)])
  y = odeint(ode_func, x, t)
  return jnp.sum(jnp.square(y))  

@partial(jax.jit, static_argnums=(1,))
def optimize(start, n_integration_steps):
  """
  Optimizes the initial condition of the ODE system using L-BFGS.
  """
  loss_fn = partial(loss, n_integration_steps=n_integration_steps)
  optim_results = tfp.optimizer.lbfgs_minimize(
        jax.value_and_grad(loss_fn),
        initial_position=start,
        num_correction_pairs=10,
        tolerance=1e-8)
  return optim_results.converged, optim_results.position

ndims = 3
n_integration_steps = 10
start = jnp.arange(ndims, 0, -1.)
converged, results = optimize(start, n_integration_steps)

[Figures: compile time (compile_times) and jaxpr string length (jaxpr_str_len) vs. number of odeint integration steps]
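For reference, the jaxpr-length metric can be reproduced without TFP. The sketch below assumes a recent JAX release in which jax.experimental.ode.odeint is still available, and measures len(str(jax.make_jaxpr(...)(...))) for jax.value_and_grad of the loss alone, leaving out the L-BFGS wrapper. jaxpr_length is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def ode_func(x, t):
    # Linear decay, as in the repro above.
    return -0.1 * x


def loss(x, n_integration_steps):
    t = 0.1 * jnp.arange(n_integration_steps)  # output times 0.0, 0.1, ...
    y = odeint(ode_func, x, t)
    return jnp.sum(jnp.square(y))


def jaxpr_length(n_integration_steps, ndims=3):
    """Hypothetical helper: printed size of the jaxpr for value_and_grad(loss)."""
    f = jax.value_and_grad(lambda x: loss(x, n_integration_steps))
    return len(str(jax.make_jaxpr(f)(jnp.ones(ndims))))


for n in (5, 10, 20):
    print(n, jaxpr_length(n))
```

Comparing these numbers with the ones obtained through lbfgs_minimize can help attribute program size to the optimizer versus odeint itself.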

hawkinsp commented 4 years ago

What hardware platform is this? CPU/GPU/TPU?

tfrerix commented 4 years ago

The plots are obtained on a CPU. They look similar but shifted on a GPU.

tfrerix commented 4 years ago

Add-on: with vmap, compilation takes about 3-4 times longer on my machine, roughly 2 min for 20 integration steps. Here is the vmapped version I used:

ndims = 3
n_integration_steps = 20
start = jnp.arange(ndims, 0, -1.)
converged, results = jax.vmap(optimize, in_axes=(0,None))(jnp.tile(start[None], [10,1]), n_integration_steps)

mattjj commented 4 years ago

I don't have anything helpful yet, but just wanted to chime in to say: thanks for raising this and for the very clear repro! We'd certainly like to support physics work, so we're keen to learn about issues like this and try to fix them!

cc @brianwa84 as FYI on TFP (EDIT: though it looks like this may be a JAX issue to solve)

shoyer commented 4 years ago

Here's the version that does the vmap inside jit. It takes 100 seconds to compile in public Colab:

ndims = 3
n_integration_steps = 10
start = jnp.arange(ndims, 0, -1.)

optimize_vmap = jax.jit(jax.vmap(optimize, in_axes=(0, None)), static_argnums=(1,))
%time converged, results = optimize_vmap(start[None, :], n_integration_steps)
# CPU times: user 1min 39s, sys: 1.59 s, total: 1min 40s
# Wall time: 1min 39s

(subsequent invocations run in only 3 ms!)

brianwa84 commented 4 years ago

It looks like the repeated tracing of the loss function is related to recomputing the JVP in many places. Is there a way to cache JVP tracing?

loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=2/2)>
  with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=1/2)>
       tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=1/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=2/2)>
  with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=1/2)>
       tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=1/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=4/2)>
  with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=3/2)>
       tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=3/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=6/2)>
  with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=5/2)>
       tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=5/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=7/2)>
  with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=6/2)>
       tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=6/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=8/2)>
  with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=7/2)>
       tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=7/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=6/2)>
  with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=5/2)>
       tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=5/2)> 7

shoyer commented 4 years ago

Indeed, it looks like we currently do a lot of retracing. I suspect there's even more of it due to repeated calls of the function inside the nested odeint (at least 4x: initial step-size estimation and RK integration, each in both the forward and backward passes).

It sounds like caching JVP tracing will be easy after the omnistaging branch lands -- which unfortunately is currently incompatible with TFP. See @mattjj's comment over in https://github.com/google/jax/pull/3370#issuecomment-663789953

mattjj commented 4 years ago

Copying from that comment on #3370 for convenience: I think we just need to memoize this line to avoid retracing.
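A minimal, JAX-independent sketch of what memoizing a trace could look like: cache the traced result keyed on the function object plus a hashable abstract signature (shapes and dtypes, never values). expensive_trace and memoized_trace are hypothetical names; this is an illustration of the idea, not how JAX implements its caches internally.

```python
import functools

trace_count = 0


def expensive_trace(fn, avals):
    """Hypothetical stand-in for tracing fn (e.g. building its JVP) at the
    given abstract values: only shapes and dtypes matter, never data."""
    global trace_count
    trace_count += 1
    return ("traced", fn, avals)


@functools.lru_cache(maxsize=None)
def memoized_trace(fn, avals):
    # Cache key: the function object (hashed by identity) plus a hashable
    # abstract signature, here a tuple of (dtype, shape) pairs.
    return expensive_trace(fn, avals)


def loss(x, n_integration_steps):
    return x  # body irrelevant: memoized_trace never calls fn


avals = (("float32", (3,)),)
memoized_trace(loss, avals)  # miss: traces
memoized_trace(loss, avals)  # hit: reuses the cached result
# A freshly built partial is a different object, hence a different key:
memoized_trace(functools.partial(loss, n_integration_steps=10), avals)
memoized_trace(functools.partial(loss, n_integration_steps=10), avals)
print(trace_count)
```

One design consequence: because the key uses function identity, a functools.partial created afresh each time (like loss_fn, built inside optimize above) is a new key on every creation. Identity keys are cheap and safe, but they only pay off when callers reuse the same function object.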

brianwa84 commented 4 years ago

TFP is now (as far as our tests show) compatible with omnistaging, which you can turn on with jax.config.enable_omnistaging(). It will be good to see whether we can now get these speedups.

mattjj commented 4 years ago

I ran the code in the OP for several values of n_integration_steps, with a print(f"loss traced! {n_integration_steps} {time.time()}") added at the top of the loss function. It seems that

  1. loss tracing happens 10 times independent of the value of n_integration_steps
  2. loss tracing only takes a couple seconds of the total execution time
  3. the total execution time is not very high.

Here's one example execution:

$ time python issue3847.py
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
2020-09-15 20:52:30.413012: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
loss traced! 100 1600228351.4273221
loss traced! 100 1600228351.5238543
loss traced! 100 1600228351.639439
loss traced! 100 1600228351.7595377
loss traced! 100 1600228351.9571645
loss traced! 100 1600228352.0764115
loss traced! 100 1600228352.1986306
loss traced! 100 1600228352.3199167
loss traced! 100 1600228352.5411289
loss traced! 100 1600228352.6583369

________________________________________________________
Executed in   16.72 secs   fish           external
   usr time   16.76 secs    0.00 micros   16.76 secs
   sys time    0.75 secs  671.00 micros    0.75 secs

Why'd we think that tracing took significant time here? Could XLA:CPU compilation time have improved significantly for this application?

(I was hoping to add some quick memoization and get a big win, but now the win is not clear!)
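The print-based instrumentation used in this comment can be sketched as follows. A Python side effect such as print or a list append runs only while JAX traces the function, not when the compiled version executes, so counting these events counts traces. A toy loss stands in for the odeint loss so the sketch runs without TFP; trace_log is a hypothetical name.

```python
import time

import jax
import jax.numpy as jnp

trace_log = []  # one (n_integration_steps, timestamp) entry per trace of `loss`


def loss(x, n_integration_steps):
    # Side effects like these run at trace time only; compiled calls skip them.
    trace_log.append((n_integration_steps, time.time()))
    print(f"loss traced! {n_integration_steps} {time.time()}")
    return n_integration_steps * jnp.sum(jnp.square(x))


loss_jit = jax.jit(loss, static_argnums=(1,))
loss_jit(jnp.ones(3), 10)   # first call: traces and compiles
loss_jit(jnp.zeros(3), 10)  # same shape, dtype and static arg: cached, no trace
loss_jit(jnp.ones(3), 20)   # new static argument: traces again
print(len(trace_log))
```

The gaps between consecutive timestamps bound how long each trace took, although they also include any other work done in between.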

mattjj commented 4 years ago

Based on a hunch from @shoyer, I tried testing with omnistaging disabled (it's now on by default after #4038; see also #3370):

$ time env JAX_OMNISTAGING=0 python issue3847.py
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
2020-09-15 20:56:48.133112: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
loss traced! 100 1600228609.162424
loss traced! 100 1600228609.5745056
loss traced! 100 1600228610.3911376
loss traced! 100 1600228611.190639
loss traced! 100 1600228611.2361348
loss traced! 100 1600228611.2457585
loss traced! 100 1600228612.115048
loss traced! 100 1600228612.8506565
loss traced! 100 1600228613.7472556
loss traced! 100 1600228614.4756129

________________________________________________________
Executed in   28.30 secs   fish           external
   usr time   28.19 secs  588.00 micros   28.19 secs
   sys time    1.07 secs  162.00 micros    1.07 secs

I didn't expect omnistaging would help automatically here, but I guess it did.

mattjj commented 4 years ago

@tfrerix can you check performance against GitHub master and let me know how things look to you? I can still look into memoization if necessary, but afaict it looks like the place to improve is in XLA compile times for this workload (rather than JAX tracing).

mattjj commented 4 years ago

Here's the vmap version, which shows at most ~1.5 s of tracing time for loss:

$ time python issue3847.py
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
2020-09-15 21:44:57.182616: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
loss traced! 100 1600231498.1988375
loss traced! 100 1600231498.6347728
loss traced! 100 1600231498.7496977
loss traced! 100 1600231498.8695302
loss traced! 100 1600231498.9872
loss traced! 100 1600231499.106297
loss traced! 100 1600231499.318343
loss traced! 100 1600231499.4400368
loss traced! 100 1600231499.5684063
loss traced! 100 1600231499.6852055

________________________________________________________
Executed in  106.88 secs   fish           external
   usr time  106.83 secs  695.00 micros  106.83 secs
   sys time    0.82 secs  190.00 micros    0.82 secs

The pre-omnistaging version is faster here:

$ time env JAX_OMNISTAGING=0 python issue3847.py
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
2020-09-15 21:47:19.647851: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
loss traced! 100 1600231640.6922941
loss traced! 100 1600231642.1424437
loss traced! 100 1600231642.9642553
loss traced! 100 1600231643.761223
loss traced! 100 1600231643.8060021
loss traced! 100 1600231643.816165
loss traced! 100 1600231644.6975186
loss traced! 100 1600231645.4296548
loss traced! 100 1600231646.2049031
loss traced! 100 1600231647.0686193

________________________________________________________
Executed in   74.20 secs   fish           external
   usr time   74.28 secs  633.00 micros   74.28 secs
   sys time    1.12 secs  118.00 micros    1.12 secs

I suspect that omnistaging made tracing faster in some ways but, because it stages out more computations to XLA and thus gives XLA bigger programs, can make XLA's compilation times slower. I think the fix there is to work with the XLA:CPU team to get faster compiles.

tfrerix commented 4 years ago

@mattjj, thanks for the additional performance tests!

I tested the code in the OP with and without omnistaging for the following installs (note that I also took the latest tfp-nightly):

! pip install -U -q jaxlib==0.1.55
! pip install -U -q git+https://github.com/google/jax.git@8376d92049624bf0784647b17b1f09015acd0947
! pip install -U -q tfp-nightly==0.12.0.dev20200918

The plot below shows that it runs significantly faster with omnistaging (on a CPU in a Colab Pro notebook).

[Figure: compile times with and without omnistaging (compile_times_omnistaging)]

shoyer commented 4 years ago

I just wanted to note that this is still an issue even with omnistaging, particularly the vmap(optimize) case:

ndims = 3
n_integration_steps = 10
start = jnp.arange(ndims, 0, -1.)
batched_start = jnp.tile(start[None], [10,1])
optimize_vmap_jit = jax.jit(jax.vmap(optimize, in_axes=(0, None)), static_argnums=(1,))

I measure jax.device_get(optimize_vmap_jit(batched_start, n_integration_steps)) taking 1 min 36s on a TPU, but that only includes 5.8 seconds doing XLA compilation and about 10 ms doing the actual computation. The rest is (presumably) overhead due to JAX tracing, though the trace is so large that the TensorBoard profiler seems to be unable to fully load it.

mattjj commented 3 years ago

> The rest is (presumably) overhead due to JAX tracing, though the trace is so large that the TensorBoard profiler seems to be unable to fully load it.

We should measure the tracing time to verify that it is really time spent in tracing. IIUC in my experiments above, tracing time was very small. (Maybe we should have better logging and instrumentation for tracing...)
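One way to separate these phases in current JAX is the ahead-of-time API (jax.jit(...).lower(...) followed by .compile()), which was added after this thread. The sketch below times trace-plus-lowering, XLA compilation, and execution separately. profile_jit and toy_loss are hypothetical names, and the Python loop in toy_loss is there only to produce a program whose size grows with n.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp


def profile_jit(fn, *args):
    """Hypothetical helper: time JAX tracing+lowering, XLA compile, and execution."""
    timings = {}
    jitted = jax.jit(fn)

    start = time.perf_counter()
    lowered = jitted.lower(*args)  # JAX side: trace to a jaxpr, emit the lowered module
    timings["trace_and_lower"] = time.perf_counter() - start

    start = time.perf_counter()
    compiled = lowered.compile()  # XLA side: compile the lowered module
    timings["xla_compile"] = time.perf_counter() - start

    start = time.perf_counter()
    out = jax.block_until_ready(compiled(*args))  # wait out async dispatch
    timings["execute"] = time.perf_counter() - start
    return out, timings


def toy_loss(x, n):
    # Python loop: unrolled at trace time, so the program grows with n.
    for _ in range(n):
        x = x - 0.1 * x
    return jnp.sum(jnp.square(x))


out, timings = profile_jit(jax.value_and_grad(partial(toy_loss, n=50)), jnp.ones(3))
for phase, seconds in timings.items():
    print(f"{phase}: {seconds:.4f} s")
```

If trace_and_lower dominates, the cost is on the JAX side; if xla_compile dominates, it is in XLA.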

shoyer commented 3 years ago

It looks like things have improved a little bit since I last tried this. Running my example above from https://github.com/google/jax/issues/3847#issuecomment-702845696 on a TPU, here are the profiling results I see:

[Image: profiler results, expanded]

or unexpanded:

[Image: profiler results, unexpanded]

At the top level, I see:

Googlers: see go/vmap-optimize-ode-jax-profile for details

mattjj commented 3 years ago

@shoyer you are the best!

jaxpr_subcomp is where we translate from a jaxpr to HLO, i.e. we walk the jaxpr and make calls into the XLA builder API. The same caching that would speed up trace_to_jaxpr_final would also speed up jaxpr_subcomp (because if we produce the same jaxpr objects, by object id, in the first step, then we can cache the subcomputation building in the second step).
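The two-level caching argument can be illustrated with a small pure-Python model that involves no JAX: stage one (standing in for trace_to_jaxpr_final) is memoized on the function and its abstract signature, and stage two (standing in for jaxpr_subcomp) is memoized on the identity of the object stage one returns. Every name here is hypothetical.

```python
import functools

counts = {"trace": 0, "lower": 0}


class Jaxpr:
    """Hypothetical stand-in for a traced program. No __eq__/__hash__ is
    defined, so instances hash and compare by object identity."""

    def __init__(self, name, avals):
        self.name = name
        self.avals = avals


@functools.lru_cache(maxsize=None)
def trace_to_jaxpr(fn, avals):
    # Stage 1: cached on (fn, avals), so a hit returns the same Jaxpr object.
    counts["trace"] += 1
    return Jaxpr(fn.__name__, avals)


@functools.lru_cache(maxsize=None)
def jaxpr_to_hlo(jaxpr):
    # Stage 2: cached on the Jaxpr's identity; only hits for the same object.
    counts["lower"] += 1
    return f"hlo<{jaxpr.name}:{jaxpr.avals}>"


def build(fn, avals):
    return jaxpr_to_hlo(trace_to_jaxpr(fn, avals))


def loss(x):
    return x


avals = (("float32", (3,)),)
build(loss, avals)
build(loss, avals)                  # both stages hit their caches
jaxpr_to_hlo(Jaxpr("loss", avals))  # equal contents, new object: stage 2 misses
print(counts)
```

The last call shows why the first cache matters: an equivalent but newly built object misses the identity-keyed second cache. functools.lru_cache also keeps strong references to its keys, so an object's id cannot be recycled while its entry is alive, a hazard that a cache keyed on raw id() values would have to handle separately.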