Open tfrerix opened 4 years ago
What hardware platform is this? CPU/GPU/TPU?
The plots are obtained on a CPU. They look similar but shifted on a GPU.
Add-on: with vmap it takes about a factor of 3-4 longer to compile on my machine, roughly 2 minutes for 20 integration steps. Here is the vmapped version I used:
ndims = 3
n_integration_steps = 20
start = jnp.arange(ndims, 0, -1.)
converged, results = jax.vmap(optimize, in_axes=(0,None))(jnp.tile(start[None], [10,1]), n_integration_steps)
I don't have anything helpful yet, but just wanted to chime in to say: thanks for raising this and for the very clear repro! We'd certainly like to support physics work, so we're keen to learn about issues like this and try to fix them!
cc @brianwa84 as FYI on TFP (EDIT: though it looks like this may be a JAX issue to solve)
Here's the version that does the vmap inside jit. It takes 100 seconds to compile in public Colab:
ndims = 3
n_integration_steps = 10
start = jnp.arange(ndims, 0, -1.)
optimize_vmap = jax.jit(jax.vmap(optimize, in_axes=(0, None)), static_argnums=(1,))
%time converged, results = optimize_vmap(start[None, :], n_integration_steps)
# CPU times: user 1min 39s, sys: 1.59 s, total: 1min 40s
# Wall time: 1min 39s
(subsequent invocations run in only 3 ms!)
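The gap between the first call and later calls can be reproduced on a much smaller scale. The following is a minimal sketch, assuming a recent JAX install; `toy_step` is a hypothetical stand-in for the real `optimize` function and is far cheaper to compile.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x):
    # Hypothetical stand-in for `optimize`; the Python loop is unrolled
    # at trace time, so the compiled program contains 50 copies of the body.
    for _ in range(50):
        x = jnp.sin(x) + 0.1 * x
    return x


x = jnp.arange(3, 0, -1.0)

t0 = time.perf_counter()
first = toy_step(x).block_until_ready()   # trace + compile + run
t1 = time.perf_counter()
second = toy_step(x).block_until_ready()  # cached executable: run only
t2 = time.perf_counter()

first_call_s = t1 - t0
second_call_s = t2 - t1
print(f"first call:  {first_call_s:.4f} s")
print(f"second call: {second_call_s:.6f} s")
```

`block_until_ready()` is needed because JAX dispatches asynchronously; without it the timer may stop before the computation has finished.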
It looks like the repeated tracing of the loss function is related to recomputing jvp in many places. Is there a way to cache JVP tracing?
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=2/2)>
with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=1/2)>
tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=1/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=2/2)>
with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=1/2)>
tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=1/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=4/2)>
with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=3/2)>
tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=3/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=6/2)>
with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=5/2)>
tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=5/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=7/2)>
with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=6/2)>
tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=6/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=8/2)>
with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=7/2)>
tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=7/2)> 7
loss Traced<ShapedArray(float32[3])>with<JVPTrace(level=6/2)>
with primal = Traced<ShapedArray(float32[3]):JaxprTrace(level=5/2)>
tangent = Traced<ShapedArray(float32[3]):JaxprTrace(level=5/2)> 7
Indeed, it looks like we do lots of retracing currently. I suspect there's even more of that due to repeated calls of the function (at least 4x: initial step size estimation and RK integration, for both the forward and backward passes) inside the nested odeint.
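One way to observe retracing directly is to count Python-level calls to the function, since Python side effects run only while JAX traces. The sketch below is a toy illustration, not the code from this issue: `loss` and `two_gradients` are hypothetical names, and the example only shows that each separate differentiation re-traces `loss` while a repeated jitted call does not.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def loss(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return jnp.sum(x ** 2)


@jax.jit
def two_gradients(x):
    # Each jax.grad call traces `loss` again from scratch.
    return jax.grad(loss)(x) + jax.grad(loss)(x + 1.0)


x = jnp.ones(3)
out = two_gradients(x)
count_after_first_call = trace_count

two_gradients(x)  # same shape and dtype: jit cache hit, no new tracing
count_after_second_call = trace_count

print(count_after_first_call, count_after_second_call)
```

In the nested odeint case the same effect compounds, because the function is called in several places at every level of differentiation.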
It sounds like caching JVP tracing will be easy after the omnistaging branch lands, which unfortunately is currently incompatible with TFP. See @mattjj's comment over in https://github.com/google/jax/pull/3370#issuecomment-663789953
Copying from that comment on 3370 for convenience, I think we just need to memoize this line to avoid retracing.
TFP is now (as best our tests show us) compatible with omnistaging, which you can turn on with jax.config.enable_omnistaging(). It will be good to see if we can now get these speedups.
Running the code in the OP and adjusting n_integration_steps with a print(f"loss traced! {n_integration_steps} {time.time()}") at the top of the loss function, it seems that n_integration_steps
Here's one example execution:
$ time python issue3847.py
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
2020-09-15 20:52:30.413012: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
loss traced! 100 1600228351.4273221
loss traced! 100 1600228351.5238543
loss traced! 100 1600228351.639439
loss traced! 100 1600228351.7595377
loss traced! 100 1600228351.9571645
loss traced! 100 1600228352.0764115
loss traced! 100 1600228352.1986306
loss traced! 100 1600228352.3199167
loss traced! 100 1600228352.5411289
loss traced! 100 1600228352.6583369
________________________________________________________
Executed in 16.72 secs fish external
usr time 16.76 secs 0.00 micros 16.76 secs
sys time 0.75 secs 671.00 micros 0.75 secs
Why'd we think that tracing took significant time here? Could XLA:CPU compilation time have improved significantly for this application?
(I was hoping to add some quick memoization and get a big win, but now the win is not clear!)
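The reasoning behind that question can be made explicit with the timestamps from the run above. This is a rough sketch: it measures only the span between the first and last trace of `loss`, so any tracing before or after those prints is not captured.

```python
# Lines copied from the "loss traced!" output of the run above.
trace_log = """\
loss traced! 100 1600228351.4273221
loss traced! 100 1600228351.5238543
loss traced! 100 1600228351.639439
loss traced! 100 1600228351.7595377
loss traced! 100 1600228351.9571645
loss traced! 100 1600228352.0764115
loss traced! 100 1600228352.1986306
loss traced! 100 1600228352.3199167
loss traced! 100 1600228352.5411289
loss traced! 100 1600228352.6583369
"""

total_wall_s = 16.72  # "Executed in 16.72 secs" from the same run

# The last whitespace-separated field on each line is the time.time() value.
timestamps = [float(line.split()[-1]) for line in trace_log.splitlines()]
trace_span_s = max(timestamps) - min(timestamps)
fraction = trace_span_s / total_wall_s

print(f"{len(timestamps)} traces over {trace_span_s:.2f} s "
      f"({fraction:.0%} of {total_wall_s} s wall time)")
```

The ten traces fall within roughly 1.2 s of a 16.72 s run, which is why the remaining time looks more like compilation than tracing of `loss`.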
Based on a hunch from @shoyer I tried testing with omnistaging disabled (since it's now on by default after #4038, see also #3370)
$ time env JAX_OMNISTAGING=0 python issue3847.py
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
2020-09-15 20:56:48.133112: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
loss traced! 100 1600228609.162424
loss traced! 100 1600228609.5745056
loss traced! 100 1600228610.3911376
loss traced! 100 1600228611.190639
loss traced! 100 1600228611.2361348
loss traced! 100 1600228611.2457585
loss traced! 100 1600228612.115048
loss traced! 100 1600228612.8506565
loss traced! 100 1600228613.7472556
loss traced! 100 1600228614.4756129
________________________________________________________
Executed in 28.30 secs fish external
usr time 28.19 secs 588.00 micros 28.19 secs
sys time 1.07 secs 162.00 micros 1.07 secs
I didn't expect omnistaging would help automatically here, but I guess it did.
@tfrerix can you check performance against GitHub master and let me know how things look to you? I can still look into memoization if necessary, but afaict it looks like the place to improve is in XLA compile times for this workload (rather than JAX tracing).
Here's the vmap version, which shows at most ~1.5 sec in tracing time for loss:
$ time python issue3847.py
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
2020-09-15 21:44:57.182616: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
loss traced! 100 1600231498.1988375
loss traced! 100 1600231498.6347728
loss traced! 100 1600231498.7496977
loss traced! 100 1600231498.8695302
loss traced! 100 1600231498.9872
loss traced! 100 1600231499.106297
loss traced! 100 1600231499.318343
loss traced! 100 1600231499.4400368
loss traced! 100 1600231499.5684063
loss traced! 100 1600231499.6852055
________________________________________________________
Executed in 106.88 secs fish external
usr time 106.83 secs 695.00 micros 106.83 secs
sys time 0.82 secs 190.00 micros 0.82 secs
The pre-omnistaging version is faster here:
$ time env JAX_OMNISTAGING=0 python issue3847.py
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
2020-09-15 21:47:19.647851: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
loss traced! 100 1600231640.6922941
loss traced! 100 1600231642.1424437
loss traced! 100 1600231642.9642553
loss traced! 100 1600231643.761223
loss traced! 100 1600231643.8060021
loss traced! 100 1600231643.816165
loss traced! 100 1600231644.6975186
loss traced! 100 1600231645.4296548
loss traced! 100 1600231646.2049031
loss traced! 100 1600231647.0686193
________________________________________________________
Executed in 74.20 secs fish external
usr time 74.28 secs 633.00 micros 74.28 secs
sys time 1.12 secs 118.00 micros 1.12 secs
I suspect that omnistaging made tracing faster in some ways but, because it stages out more computations to XLA and thus gives XLA bigger programs, can make XLA's compilation times slower. I think the fix there is to work with the XLA:CPU team to get faster compiles.
@mattjj, thanks for the additional performance tests!
I tested the code in the OP with and without omnistaging for the following installs (note that I also took the latest tfp-nightly):
! pip install -U -q jaxlib==0.1.55
! pip install -U -q git+https://github.com/google/jax.git@8376d92049624bf0784647b17b1f09015acd0947
! pip install -U -q tfp-nightly==0.12.0.dev20200918
The plot below shows that it runs significantly faster with omnistaging (measured on a CPU in a Colab Pro notebook).
I just wanted to note that this is still an issue even with omnistaging, particularly the vmap(optimize) case:
ndims = 3
n_integration_steps = 10
start = jnp.arange(ndims, 0, -1.)  # same starting point as in the earlier snippets
batched_start = jnp.tile(start[None], [10,1])
optimize_vmap_jit = jax.jit(jax.vmap(optimize, in_axes=(0, None)), static_argnums=(1,))
I measure jax.device_get(optimize_vmap_jit(batched_start, n_integration_steps)) taking 1 min 36 s on a TPU, but that only includes 5.8 seconds doing XLA compilation and about 10 ms doing the actual computation. The rest is (presumably) overhead due to JAX tracing, though the trace is so large that the TensorBoard profiler seems to be unable to fully load it.
The rest is (presumably) overhead due to JAX tracing, though the trace is so large that the TensorBoard profiler seems to be unable to fully load it.
We should measure the tracing time to verify that it is really time spent in tracing. IIUC in my experiments above, tracing time was very small. (Maybe we should have better logging and instrumentation for tracing...)
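One possible way to separate the phases, sketched here under the assumption of a recent JAX release that provides the ahead-of-time `lower`/`compile` API on jitted functions (it did not exist when this thread started). `toy_fn` is a hypothetical stand-in for the vmapped `optimize` call.

```python
import time

import jax
import jax.numpy as jnp


def toy_fn(x):
    # Hypothetical stand-in for the real workload; the loop is unrolled
    # during tracing.
    for _ in range(20):
        x = jnp.tanh(x) + 0.5 * x
    return x


x = jnp.ones((10, 3))

t0 = time.perf_counter()
jaxpr = jax.make_jaxpr(toy_fn)(x)        # tracing to a jaxpr only
t1 = time.perf_counter()
lowered = jax.jit(toy_fn).lower(x)       # tracing again, plus lowering
t2 = time.perf_counter()
compiled = lowered.compile()             # XLA compilation
t3 = time.perf_counter()
result = compiled(x).block_until_ready()  # execution of the compiled program
t4 = time.perf_counter()

print(f"trace to jaxpr : {t1 - t0:.4f} s")
print(f"trace + lower  : {t2 - t1:.4f} s")
print(f"XLA compile    : {t3 - t2:.4f} s")
print(f"execute        : {t4 - t3:.6f} s")
```

If the first two numbers dominate for the real workload, the time is going into JAX tracing and lowering rather than into XLA.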
It looks like things have improved a little bit since I last tried this. Running my example above from https://github.com/google/jax/issues/3847#issuecomment-702845696 on a TPU, here are the profiling results I see:
or unexpanded:
At the top level, I see:
Googlers: see go/vmap-optimize-ode-jax-profile for details
@shoyer you are the best!
jaxpr_subcomp is where we translate from a jaxpr to HLO, i.e. we walk the jaxpr and make calls into the XLA builder API. The same caching that would speed up trace_to_jaxpr_final would also speed up jaxpr_subcomp (because if we produce the same jaxpr objects, by object id, in the first step, then we can cache the subcomputation building in the second step).
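The two-stage caching idea can be illustrated in plain Python. This is a toy model only: `ToyIR`, `trace_to_ir`, and `build_subcomputation` are hypothetical names and do not correspond to JAX internals. It shows that once the first stage is memoized and returns the same object, the second stage can cache on that object's identity.

```python
import functools

trace_calls = 0
build_calls = 0


class ToyIR:
    """Hypothetical stand-in for a traced program (e.g. a jaxpr)."""

    def __init__(self, name, shape):
        self.name = name
        self.shape = shape


@functools.lru_cache(maxsize=None)
def trace_to_ir(fun, arg_shape):
    # Stage 1: memoized on (function object, input shape), so repeated
    # requests return the *same* ToyIR object.
    global trace_calls
    trace_calls += 1
    return ToyIR(fun.__name__, arg_shape)


_subcomp_cache = {}


def build_subcomputation(ir):
    # Stage 2: cache keyed on object identity. This is safe here because
    # the stage-1 cache keeps every ToyIR alive, so ids are never reused.
    global build_calls
    key = id(ir)
    if key not in _subcomp_cache:
        build_calls += 1
        _subcomp_cache[key] = f"hlo<{ir.name}:{ir.shape}>"
    return _subcomp_cache[key]


def loss(x):
    return x


ir_a = trace_to_ir(loss, (3,))
ir_b = trace_to_ir(loss, (3,))  # cache hit: same object as ir_a
hlo_a = build_subcomputation(ir_a)
hlo_b = build_subcomputation(ir_b)  # cache hit via id(ir_b) == id(ir_a)

print(trace_calls, build_calls)
```

Without the stage-1 memoization, each trace would produce a distinct object and the identity-keyed cache in stage 2 would never hit.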
I want to optimize the initial condition of a dynamical system w.r.t. a criterion based on the system's trajectory. To this end, I use jax.experimental.ode.odeint inside of tfp.optimizer.lbfgs_minimize (cf. TFP-on-JAX L-BFGS). This might sound very particular, but I suppose this will come up more often when using JAX in the context of physics models.
However, jit-compilation is very slow, which impedes scaling up. Is there a good reason for the slow compilation, or is there a way to speed it up?
Below you see a code example and two graphs that show the compile times and the length of the code generated during jit-compilation, i.e., len(str(jax.make_jaxpr(...)(...))), as a function of odeint integration steps. The curves have similar shapes and become somewhat flat.