nikolas-claussen opened this issue 10 months ago
It's a little hard to tease out an explanation for each individual case you've tested, but fundamentally, recompilations happen every time you pass in a new function, a new bool/int/float/complex (that isn't wrapped into a JAX array), or an array with a new shape/dtype. One example that is straightforward to explain is putting my_event inside the loop: you are then creating a fresh lambda function on every iteration (and Python offers no way to detect that it looks identical to the lambda functions you created previously), and that is what causes the recompilation.
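A minimal sketch of the recompilation behaviour described above, using plain jax.jit rather than diffrax. The names apply, add_one and trace_count are hypothetical. The counter is a Python side effect, so it only increments when JAX traces (and therefore compiles) the function. The function is passed as a static argument, so it becomes part of the compilation cache key. Filter-style wrappers such as equinox.filter_jit similarly treat every non-array argument (functions, Python scalars) as static.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces, i.e. on each (re)compile


@functools.partial(jax.jit, static_argnums=0)
def apply(fn, x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return fn(x)


def add_one(x):
    return x + 1.0


x = jnp.zeros(3)

# Reusing the same function object: traced once, then served from the cache.
for _ in range(3):
    apply(add_one, x)
reused_traces = trace_count

# A fresh lambda on every iteration: each one is a new cache key, so it retraces.
trace_count = 0
for _ in range(3):
    apply(lambda x: x + 1.0, x)
fresh_traces = trace_count

# A new input shape also forces a retrace; a repeated shape hits the cache.
trace_count = 0
apply(add_one, jnp.zeros(4))
apply(add_one, jnp.zeros(5))
apply(add_one, jnp.zeros(4))
shape_traces = trace_count

print(reused_traces, fresh_traces, shape_traces)  # 1 3 2
```

Defining my_event once, outside the loop, corresponds to the add_one case: the same object is passed every time, so the cached compilation is reused.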
Fundamentally, what you almost certainly want to do is JIT your whole computation, including the diffeqsolve, rather than JIT-ing individual pieces. See point 1 in this guidance. You can convert your Python while loop into a jax.lax.while_loop to make this possible.
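A sketch of the "JIT everything" structure on a toy problem. It assumes solve_until_event as a stand-in for diffeqsolve with a terminating event (explicit Euler on dy/dt = -rate * y, stopping once y drops below 0.5) and modify_args as a stand-in for the restart logic. Both are hypothetical and are not diffrax code. The outer Python while loop becomes a jax.lax.while_loop, so the whole restart loop compiles once. Later calls with a perturbed initial condition of the same shape and dtype reuse the compiled program.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how often `run` is traced, i.e. compiled


def solve_until_event(y, rate, dt=0.01, max_steps=10_000):
    """Hypothetical stand-in for diffeqsolve + terminating event.

    Explicit Euler for dy/dt = -rate * y, stopping once y < 0.5.
    """
    def cond(state):
        y, step = state
        return (y >= 0.5) & (step < max_steps)

    def body(state):
        y, step = state
        return y - dt * rate * y, step + 1

    y_final, _ = jax.lax.while_loop(cond, body, (y, jnp.int32(0)))
    return y_final


def modify_args(y, rate):
    # Hypothetical restart rule: kick the state back up and halve the decay rate.
    return y + 1.0, rate * 0.5


@jax.jit
def run(y0, rate0, n_restarts):
    global trace_count
    trace_count += 1  # Python side effect: only runs while tracing

    def cond(state):
        _, _, i = state
        return i < n_restarts

    def body(state):
        y, rate, i = state
        y = solve_until_event(y, rate)  # integrate until the "event" fires
        y, rate = modify_args(y, rate)  # then modify the arguments and restart
        return y, rate, i + 1

    y, rate, _ = jax.lax.while_loop(cond, body, (y0, rate0, jnp.int32(0)))
    return y, rate


y_a, rate_a = run(jnp.asarray(1.0), jnp.asarray(1.0), 3)
y_b, rate_b = run(jnp.asarray(1.0 + 1e-5), jnp.asarray(1.0), 3)  # perturbed y0
print(trace_count)  # 1: the perturbed initial condition reuses the compiled program
```

The inner loop is nested inside the outer loop's body, so the event check, the argument modification and the restart counter all live inside one compiled program. n_restarts is passed as a traced argument rather than a static one, so changing it does not force a recompile either.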
Thanks a lot, that made it work. While JIT-ing the whole while loop, I realized that my modify_args was actually not JIT-compatible, but following your advice about the jax.lax control-flow operators, I was able to fix that.
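The actual body of modify_args is not shown in the thread, so the sketch below uses a hypothetical restart rule on a hypothetical args dict with keys "y" and "k". It illustrates the general fix. A Python if on a traced value cannot be evaluated at trace time. jax.lax.cond, or jnp.where for elementwise selection, expresses the same branch in a JIT-compatible way.

```python
import jax
import jax.numpy as jnp


def modify_args_python(args):
    # Python `if` on a traced value: raises a concretization error under jax.jit.
    if args["y"] > 1.0:
        return {"y": args["y"] * 0.5, "k": args["k"]}
    return {"y": args["y"], "k": args["k"] + 1.0}


def modify_args_lax(args):
    # Same rule with jax.lax.cond: both branches are traced once and the
    # predicate selects between them at run time, so this works under jit.
    return jax.lax.cond(
        args["y"] > 1.0,
        lambda a: {"y": a["y"] * 0.5, "k": a["k"]},
        lambda a: {"y": a["y"], "k": a["k"] + 1.0},
        args,
    )


args = {"y": jnp.asarray(2.0), "k": jnp.asarray(0.0)}

try:
    jax.jit(modify_args_python)(args)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

out = jax.jit(modify_args_lax)(args)
print(python_if_failed, float(out["y"]), float(out["k"]))  # True 1.0 0.0
```

Both branches of jax.lax.cond must return the same pytree structure with matching shapes and dtypes. That is why the first branch passes "k" through unchanged instead of dropping it.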
Hi,
I am running into a strange issue when using diffrax.diffeqsolve with the discrete_terminating_event argument, which I believe is due to a large number of JIT recompiles that make execution slow. For context, I am solving an ODE until a stopping criterion occurs. Then I make some modification to the arguments of the ODE and restart it.
All the functions (my_function, modify_args, and the function wrapped by term) are written in JAX and JITed. The first time I run the while loop (as a cell in a Jupyter notebook), it takes approx. 10 s. When I run it again with an identical my_initial_condition, it is significantly faster, approx. 0.3 s. I assume this difference is due to the JIT compilation overhead, which is no problem. However, when I re-run this with a slightly modified initial condition, e.g.
y0 = my_initial_condition + 1e-5
I am back to 10 s runtime. This is not good, because I want to run this code block a large number of times for different values of my_initial_condition.
I ran the following tests to see what might be going on:
If the event never stops the integration, the problem is gone, even if I am still passing a discrete_terminating_event argument (modified so as to never trigger a stop).
If I define my_event inside the while loop, I always get the ~10 s execution time, even if I'm re-running the cell with identical inputs.
If I call my_function (the function inside my_event) with different values of y or args, I do not trigger a JIT recompile.
This has led me to believe that diffrax JIT-recompiles the discrete_terminating_event every time integration is stopped due to an event. Is there a way to avoid this?
Best,
Nikolas