ahmadsalim opened this issue 4 years ago
Thanks for raising this; that's too slow! Which backend is this on (CPU/GPU/TPU)? It might be faster or slower to compile on different backends (often slowest on CPU, fastest on TPU).
Thanks for taking a look! I am using the CPU backend currently, since I am on macOS 😄
This sounds a little bit like https://github.com/google/jax/issues/3847 (nested odeint inside an optimizer). I wonder if the fix suggested in https://github.com/google/jax/issues/3847#issuecomment-669485433 would help here, too.
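For reference, the pattern at issue in #3847 is an ODE solve nested inside an optimizer loop. The sketch below is plain Python (Euler integration, finite-difference gradients) purely to illustrate the *shape* of that nesting; all names are hypothetical and not taken from the actual repro, which uses `jax.experimental.ode.odeint` inside a jitted training step, where tracing the whole integrator per optimization step is what inflates compile time.

```python
# Illustrative sketch only (plain Python, no JAX): an ODE solve nested
# inside an optimizer loop -- the pattern discussed in #3847.
# All names here are hypothetical, not taken from the actual repro.

def integrate(theta, y0=1.0, t1=1.0, steps=100):
    """Euler integration of dy/dt = -theta * y from t=0 to t=t1."""
    h = t1 / steps
    y = y0
    for _ in range(steps):
        y += h * (-theta * y)
    return y

def loss(theta, target):
    """Squared error between the simulated endpoint and a target."""
    return (integrate(theta) - target) ** 2

def fit(target, theta=0.0, lr=0.2, iters=1000, eps=1e-6):
    """Gradient descent with finite-difference gradients.

    Every iteration re-runs the integrator; under jit, each trace of
    this inner solve contributes to the size of the compiled program.
    """
    for _ in range(iters):
        g = (loss(theta + eps, target) - loss(theta - eps, target)) / (2 * eps)
        theta -= lr * g
    return theta
```

For example, fitting to a target of `exp(-2)` recovers a decay rate close to `theta = 2` (up to Euler discretization error).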
Thanks! :)
@ahmadsalim we might now be in a better position to make progress on this. However, your code link doesn't seem to work for me. If you're still working on this, could you share a repro? (And if you're not, could you close the issue?)
@mattjj Thanks for your comment and for looking into this. I'm glad you're now in a better position to make progress.
I was cleaning up some stuff in my branch and forgot that I linked it here. I have updated the link to a more stable version.
The issue is a bit similar to #1769.
I have a program where compilation takes roughly 10-15 minutes on each run. The program is available here. The core idea is to use NumPyro's support for Stochastic Variational Inference to learn the parameters of a Runge-Kutta ODE integrator-based simulator.
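For context, the Runge-Kutta core of such a simulator is just a fixed-step loop. Below is a minimal classic RK4 sketch in plain Python (hypothetical names, none of the NumPyro/JAX specifics of my actual program):

```python
# Minimal classic RK4 integrator, plain Python for illustration only.
# In the actual program a loop like this is written with JAX ops and
# traced under jit, which is where the long compile times come in.

def rk4_step(f, t, y, h):
    """One fourth-order Runge-Kutta step for dy/dt = f(t, y)."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def integrate(f, t0, y0, t1, steps):
    """Integrate from t0 to t1 with a fixed number of RK4 steps."""
    h = (t1 - t0) / steps
    t, y = t0, y0
    for _ in range(steps):
        y = rk4_step(f, t, y, h)
        t += h
    return y
```

For example, integrating `dy/dt = y` from 0 to 1 in 10 steps approximates `e` to within about 1e-6.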
I have included the XLA dump here: xla_debug_predator_prey.zip
Do you have any idea why compilation could be so slow? I can try to minimize the example if I can get a hint about where the problem might originate.
I appreciate all the help that I can get. Thanks a lot in advance!