patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Frequent JIT-recompile of discrete_terminating_event #333

Open nikolas-claussen opened 10 months ago

nikolas-claussen commented 10 months ago

Hi,

I am running into a strange issue when using diffrax.diffeqsolve with the discrete_terminating_event argument. I believe it is caused by a large number of JIT recompilations, which make execution slow.

For context, I am solving an ODE until a stopping criterion occurs. Then, I make some modification to the arguments of the ODE, and restart it. Schematically:

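import diffrax

# The condition function receives the solver state plus keyword arguments, including args.
my_event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: my_function(state.y, *kwargs["args"]))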
tcurrent = t0
y0 = my_initial_condition
args = my_initial_args
while tcurrent < t1:
    solution = diffrax.diffeqsolve(term, solver, args=args, t0=tcurrent, dt0=dt0, y0=y0,
                                   discrete_terminating_event=my_event,
                                   stepsize_controller=stepsize_controller, max_steps=None)
    tcurrent = float(solution.ts[-1])
    y0 = solution.ys[-1]
    args = modify_args(solution.ys[-1], args)

All the functions (my_function, modify_args, the function wrapped by term) are written in JAX and JITed. The first time I run the while loop - as a cell in a jupyter notebook - it takes approx. 10 seconds. When I run it again, with identical my_initial_condition, it is significantly faster, approx. 0.3s. I assume this difference is due to the JIT compilation overhead - no problem.

However, when I re-run this with a slightly modified initial condition, e.g. y0 = my_initial_condition+1e-5, I am back to a 10s runtime. This is not good, because I want to run this code block a large number of times for different values of my_initial_condition. I ran the following tests to see what might be going on:

This has led me to believe that diffrax JIT-recompiles the discrete_terminating_event every time integration is stopped due to an event. Is there a way to avoid this?

Best,

Nikolas

patrick-kidger commented 10 months ago

It's a little hard to tease out an explanation for each individual case you've tested, but fundamentally recompilations happen every time you pass in a new function, or a new bool/int/float/complex (that isn't wrapped into a JAX array), or when you change the shape/dtype of an array. One example that is straightforward to explain is when you put my_event inside the loop: you are then creating a fresh lambda function every time (and Python doesn't offer a way to detect that this looks identical to the previous lambda functions you've created), and so this is what causes recompilation.
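A minimal sketch of the fresh-lambda effect, using plain jax.jit with static_argnums as a stand-in for how a function argument ends up in the compilation cache key. It does not call diffrax; the names apply, my_cond, and trace_count are hypothetical and exist only for this illustration. The counter is a Python side effect, so it only increments while JAX is tracing, which makes retraces visible.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces `apply`


@functools.partial(jax.jit, static_argnums=0)
def apply(cond_fn, y):
    # `cond_fn` is a static argument, so it is part of the compilation cache key.
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not on cached calls
    return cond_fn(y)


def my_cond(y):
    # Hypothetical stopping criterion, for illustration only.
    return jnp.sum(y) > 10.0


y = jnp.array([1.0, 2.0])

# Same function object, same shape/dtype, different values: traced once.
apply(my_cond, y)
apply(my_cond, y + 1e-5)
traces_reused = trace_count

# Textually identical lambdas are still distinct objects, so each one retraces.
fresh_conds = [lambda y: jnp.sum(y) > 10.0 for _ in range(3)]
for cond in fresh_conds:
    apply(cond, y)
traces_fresh = trace_count - traces_reused

print(traces_reused, traces_fresh)
```

Defining the condition function once, outside the loop, and passing changing quantities as JAX arrays rather than Python scalars keeps the cache key stable between calls.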

Fundamentally, what you almost certainly want to do is to JIT your whole computation -- including the diffeqsolve -- and not just to JIT individual pieces. See point 1 in this guidance. You can convert your Python while loop into a jax.lax.while_loop to make this possible.
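A rough sketch of that restructuring, under loudly stated assumptions: the diffeqsolve call is replaced by a hand-written fixed-step Euler loop so that the example runs with JAX alone, and vector_field, event_triggered, modify_args, and the step size DT are all hypothetical stand-ins for the functions in the original post. The point is only the shape of the solution: the restart loop becomes a jax.lax.while_loop, and the whole thing sits under a single jax.jit.

```python
import jax
import jax.numpy as jnp

DT = 1e-2  # fixed Euler step size (assumption; the original uses a stepsize controller)
trace_count = 0  # incremented only while JAX traces `run`


def vector_field(y, args):
    return args  # toy ODE: dy/dt = args


def event_triggered(y, args):
    return y * jnp.sign(args) >= 1.0  # toy event: y reaches +/-1 while moving outward


def modify_args(y, args):
    return -args  # toy modification: reverse direction after each event


def solve_until_event(t, y, args, t1):
    # Stand-in for diffeqsolve with a terminating event.
    def cond(carry):
        t, y = carry
        return (t < t1) & ~event_triggered(y, args)

    def body(carry):
        t, y = carry
        return t + DT, y + DT * vector_field(y, args)

    return jax.lax.while_loop(cond, body, (t, y))


@jax.jit
def run(y0, args0, t0, t1):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing

    def cond(carry):
        t, y, args, n_events = carry
        return t < t1

    def body(carry):
        t, y, args, n_events = carry
        t, y = solve_until_event(t, y, args, t1)
        hit = event_triggered(y, args)
        # jnp.where replaces a Python `if`, which would not be JIT-compatible here.
        args = jnp.where(hit, modify_args(y, args), args)
        return t, y, args, n_events + hit.astype(jnp.int32)

    return jax.lax.while_loop(cond, body, (t0, y0, args0, jnp.int32(0)))


t0, t1, args0 = jnp.float32(0.0), jnp.float32(4.0), jnp.float32(1.0)
_, _, _, n_events_a = run(jnp.float32(0.0), args0, t0, t1)
_, _, _, n_events_b = run(jnp.float32(0.5), args0, t0, t1)  # new values, same dtypes
print(int(n_events_a), int(n_events_b), trace_count)
```

Because every input is a JAX array of fixed shape and dtype, changing the initial condition reuses the compiled program instead of triggering a new trace.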

nikolas-claussen commented 10 months ago

Thanks a lot - that made it work. I realized in the process of JIT-ing the whole while loop that my modify_args was actually not JIT compatible. But based on your advice about the jax.lax-control flow operators, I was able to fix that.