jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Very slow compile when doing stochastic variational inference on the parameters of an ODE integrator #3975

Open ahmadsalim opened 4 years ago

ahmadsalim commented 4 years ago

This issue is somewhat similar to #1769.

I have a program where compilation takes roughly 10-15 minutes on each run. The program is available here. The core idea is that it uses NumPyro's support for Stochastic Variational Inference to learn the parameters of a Runge-Kutta ODE integrator-based simulator.

I have included the XLA dump here: xla_debug_predator_prey.zip

Do you have any idea why compilation could be so slow? I can try to minimize the example if I can get a hint about where the problem might lie.
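For reference, a reduced repro might look something like the sketch below. It is not the linked program: it assumes `jax.experimental.ode.odeint` as the integrator in place of the program's own Runge-Kutta code, uses a plain squared-error loss instead of NumPyro's SVI objective, and the parameter values and `observed` data are made-up placeholders. It only shows one way to separate compile time from run time by timing the first and second calls of a jitted gradient.

```python
import time

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def lotka_volterra(state, t, alpha, beta, gamma, delta):
    """Predator-prey dynamics; state = [prey, predator]."""
    prey, predator = state
    d_prey = alpha * prey - beta * prey * predator
    d_predator = delta * prey * predator - gamma * predator
    return jnp.stack([d_prey, d_predator])


def loss(params, y0, ts, observed):
    """Squared error between simulated and observed trajectories."""
    trajectory = odeint(lotka_volterra, y0, ts, *params, rtol=1e-5, atol=1e-5)
    return jnp.sum((trajectory - observed) ** 2)


# Hypothetical initial state, time grid, and parameters (illustration only).
y0 = jnp.array([10.0, 5.0])
ts = jnp.linspace(0.0, 5.0, 20)
params = (jnp.array(1.1), jnp.array(0.4), jnp.array(0.4), jnp.array(0.1))
observed = jnp.ones((ts.shape[0], 2))  # placeholder data, not real observations

grad_fn = jax.jit(jax.grad(loss))

# First call: tracing + XLA compilation + execution.
start = time.perf_counter()
grads = grad_fn(params, y0, ts, observed)
for g in grads:
    g.block_until_ready()
first_call_seconds = time.perf_counter() - start

# Second call: reuses the cached executable, so this is execution only.
start = time.perf_counter()
grads = grad_fn(params, y0, ts, observed)
for g in grads:
    g.block_until_ready()
second_call_seconds = time.perf_counter() - start

print(f"first call (incl. compile): {first_call_seconds:.3f} s")
print(f"second call (cached):       {second_call_seconds:.3f} s")
```

If the gap between the two timings grows sharply as pieces of the real model are added back, that would point at the piece responsible for the slow compile.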

I appreciate all the help that I can get. Thanks a lot in advance!

mattjj commented 4 years ago

Thanks for raising this; that's too slow! Which backend is this on (CPU/GPU/TPU)? It might be faster or slower to compile on different backends (often slowest on CPU, fastest on TPU).
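In case it helps, a quick way to check which backend a JAX process is using is sketched below. It relies only on `jax.default_backend()` and `jax.devices()`; the output depends on the machine and the installed jaxlib.

```python
import jax

# Name of the platform JAX will use by default, e.g. "cpu".
backend = jax.default_backend()

# All devices visible on that default backend.
devices = jax.devices()

print(f"default backend: {backend}")
print(f"devices: {devices}")
```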

ahmadsalim commented 4 years ago

Thanks for taking a look! I am using the CPU backend currently, since I am on macOS 😄

shoyer commented 4 years ago

This sounds a little bit like https://github.com/google/jax/issues/3847 (nested odeint inside an optimizer). I wonder if the fix suggested in https://github.com/google/jax/issues/3847#issuecomment-669485433 would help here, too.

ahmadsalim commented 4 years ago

Thanks! :)

mattjj commented 4 years ago

@ahmadsalim we might now be in a better position to make progress on this. However, your code link doesn't seem to work for me. If you're still working on this, could you share a repro? (And if you're not, could you close the issue?)

ahmadsalim commented 4 years ago

@mattjj Thanks for your comment and for looking into this. I am glad that you are now in a better position to investigate.

I was cleaning up some stuff in my branch and forgot that I linked it here. I have updated the link to a more stable version.