patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Inspection of gradient calculation #280

Open moesphere opened 11 months ago

moesphere commented 11 months ago

Hello!

Is there a tutorial on how to inspect the gradient/Hessian calculation when using Diffrax?

I am trying to calculate gradients and Hessians, but the returned values are always NaN, and I do not know where to start debugging.
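A stripped-down sketch of what I am doing, with a toy vector field and parameter standing in for my actual model:

```python
import jax
import jax.numpy as jnp
import diffrax


# Toy vector field standing in for my actual model; theta is a single
# made-up parameter.
def vector_field(t, y, theta):
    return -theta * y


def loss(theta):
    term = diffrax.ODETerm(vector_field)
    sol = diffrax.diffeqsolve(
        term, diffrax.Tsit5(), t0=0.0, t1=1.0, dt0=0.1,
        y0=jnp.array([1.0]), args=theta,
    )
    return jnp.sum(sol.ys[-1] ** 2)


grad = jax.grad(loss)(0.5)  # parameter sensitivity; in my real model this comes back as NaN
```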

patrick-kidger commented 11 months ago

So there's nothing special about Diffrax -- I'd recommend debugging NaNs in the same way as any JAX program.

First of all, one very common source of NaNs during autodiff is a missing double where trick. (Needing this trick is actually a general fact of autodifferentiation systems! Nothing JAX-specific here.)
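For example, a sqrt guarded by a single where still produces a NaN gradient, because the untaken branch is differentiated too; sanitising its input with a second where fixes this (toy functions, just to illustrate the pattern):

```python
import jax
import jax.numpy as jnp


def f_bad(y):
    # Single where: the sqrt branch is still differentiated at y = -1.0,
    # so the NaN derivative of the untaken branch propagates through.
    return jnp.where(y > 0, jnp.sqrt(y), 0.0)


def f_good(y):
    # Double where: first sanitise the input so sqrt never sees a negative
    # number, then select the branch as before.
    safe_y = jnp.where(y > 0, y, 1.0)
    return jnp.where(y > 0, jnp.sqrt(safe_y), 0.0)


print(jax.grad(f_bad)(-1.0))   # nan
print(jax.grad(f_good)(-1.0))  # 0.0
```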

Second, another common source is when doing a sqrt or log of a negative number. If your vector field includes either of these operations, then consider checking whether their input is negative. (Note that Diffrax reserves the right to query your vector field with any values for t and y, even those outside the region in which the ODE is solved -- some solvers will make queries outside those ranges. So your vector field must be robust to such queries.) You can easily check this with jax.debug.print.
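As a sketch of where such a check would go (the vector field below is made up; the point is only the placement of the jax.debug.print):

```python
import jax
import jax.numpy as jnp
import diffrax


def vector_field(t, y, args):
    # Log every query Diffrax makes to the vector field, so a negative
    # input to the sqrt below shows up immediately.
    jax.debug.print("t = {t}, min(y) = {m}", t=t, m=jnp.min(y))
    return -jnp.sqrt(y)


sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field), diffrax.Tsit5(),
    t0=0.0, t1=1.0, dt0=0.1, y0=jnp.array([1.0]),
)
```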

If it's not any of those common errors, then your main tools to track this down are the usual JAX debugging utilities: jax.debug.print to log intermediate values from inside JIT-compiled code, jax.debug.breakpoint to pause and inspect them interactively, and the NaN checker enabled via jax.config.update("jax_debug_nans", True), which errors out at the first operation that produces a NaN.

Generally speaking, NaN issues aren't too tricky to track down: just bisect through your code until you find the operation that's producing them.
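One way to automate that bisection is the NaN checker mentioned above, which makes the first NaN-producing operation raise an error (standard JAX functionality, nothing Diffrax-specific; the function below is just a toy example):

```python
import jax
import jax.numpy as jnp

# Enable the NaN checker: the first primitive whose output contains a NaN
# raises an error, with a traceback pointing at the offending operation.
jax.config.update("jax_debug_nans", True)


def f(x):
    return jnp.log(x)  # NaN for negative x


jax.grad(f)(-1.0)  # raises FloatingPointError instead of silently returning NaN
```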

Moreover, I'd encourage you to be willing to place these checks inside your copy of Diffrax's code. You can get the install location via import diffrax; print(diffrax.__file__). (Placing breakpoints within the library is often a quicker way of debugging, as it gives you finer-grained control over your bisection search.)
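As a sketch of the kind of check that's convenient to paste in while bisecting (the helper name is made up; call it on whichever intermediate value you're suspicious of):

```python
import jax
import jax.numpy as jnp


def check_finite(tag, x):
    # Temporary debugging helper: report whether `x` contains any NaNs,
    # then return it unchanged so it can be dropped into existing code.
    jax.debug.print(tag + ": contains NaN = {bad}", bad=jnp.any(jnp.isnan(x)))
    return x

# e.g. somewhere inside the solver code being bisected:
# y = check_finite("after vector field evaluation", y)
```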

moesphere commented 11 months ago

Thank you for the fast and informative answer! I will look into it.

I have used the same model equations with CVODES from SUNDIALS (via CasADi). The results from the gradient calculations, i.e. the parameter sensitivities, were successful and physically feasible. Does this information favor any particular debugging strategy among the ones you mentioned above?

patrick-kidger commented 11 months ago

That probably doesn't help, I'm afraid!