patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.
Apache License 2.0

Jacobian matrix for ODE solvers #375

Open lsptrsn opened 7 months ago

lsptrsn commented 7 months ago


Is it possible to extract the Jacobian matrix used by the ODE solvers (specifically, Kvaerno5)?

Many thanks in advance.

patrick-kidger commented 7 months ago

Which Jacobian do you want specifically?

To be precise: given an ODE dy/dt = f(y), the nth implicit Runge–Kutta stage solves for the u_n such that

$$F_n(u_n) = 0,$$

where

$$F_n(u_n) = u_n - f\left(y_0 + \left(\sum_{i<n} a_i u_i + a_n u_n\right) dt\right),$$

given the value y0 at the start of a step of length dt, with Butcher tableau entries a_i. (Assuming I haven't made a typo, anyway.)
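A minimal JAX sketch of the stage residual above and its Jacobian with respect to u_n. The vector field `f`, the tableau entries `a_prev`/`a_n`, and the earlier stage values `u_prev` are all hypothetical illustration values, not anything taken from diffrax. It checks that `jax.jacfwd` of F_n matches the chain-rule form I - a_n dt ∂f/∂y.

```python
import jax
import jax.numpy as jnp

# Hypothetical autonomous vector field for dy/dt = f(y).
def f(y):
    return jnp.array([-y[0] + y[0] * y[1], -2.0 * y[1] + y[0] ** 2])

def make_stage_residual(y0, dt, a_prev, u_prev, a_n):
    """Return F_n(u_n) = u_n - f(y0 + (sum_{i<n} a_i u_i + a_n u_n) dt)."""
    def F_n(u_n):
        # a_prev: shape (n-1,); u_prev: shape (n-1, d), one row per earlier stage.
        z = y0 + (a_prev @ u_prev + a_n * u_n) * dt
        return u_n - f(z)
    return F_n

y0 = jnp.array([1.0, 0.5])
dt = 0.1
a_prev = jnp.array([0.2, 0.3])                 # illustrative a_i for i < n
u_prev = jnp.array([[0.1, -0.2], [0.3, 0.4]])  # illustrative earlier stages
a_n = 0.25                                     # illustrative diagonal entry

F_n = make_stage_residual(y0, dt, a_prev, u_prev, a_n)
u_n = jnp.array([0.05, -0.1])

# Jacobian of the stage residual with respect to the unknown stage value u_n.
jac_F = jax.jacfwd(F_n)(u_n)

# Chain rule: dF_n/du_n = I - a_n dt (df/dy)(z), with z the argument passed to f.
z = y0 + (a_prev @ u_prev + a_n * u_n) * dt
jac_expected = jnp.eye(2) - a_n * dt * jax.jacfwd(f)(z)
print(jnp.allclose(jac_F, jac_expected, atol=1e-6))
```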

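The implicit solvers used for this problem work by computing the Jacobian dF_n/du_n, which by the chain rule is I - a_n dt ∂f/∂y. SDIRK and ESDIRK solvers (like Kvaerno5) would usually evaluate this just once, at y0: all of their implicit stages share the same diagonal entry a_n, so the same matrix can be reused for every stage. (DIRK solvers would usually evaluate it again at every stage.) Is this the Jacobian you're after?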
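A minimal sketch of why the (E)SDIRK case is cheap, assuming a hypothetical vector field `f` and illustrative diagonal entries (not Kvaerno5's actual tableau). With ∂f/∂y frozen at y0, one shared diagonal entry gives one matrix that is factorised once; distinct DIRK diagonals give a different matrix per stage. Only `jax.jacfwd` and `jax.scipy.linalg.lu_factor`/`lu_solve` are used; this is not diffrax's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

# Hypothetical vector field for dy/dt = f(y).
def f(y):
    return jnp.array([-y[0] + y[0] * y[1], -2.0 * y[1] + y[0] ** 2])

y0 = jnp.array([1.0, 0.5])
dt = 0.1
eye = jnp.eye(y0.shape[0])

# Evaluate df/dy once, at the start of the step.
J_f = jax.jacfwd(f)(y0)

# (E)SDIRK: every implicit stage shares one diagonal entry gamma (illustrative),
# so a single matrix I - gamma*dt*J_f, factorised once, serves every stage.
gamma = 0.25
sdirk_matrix = eye - gamma * dt * J_f
sdirk_lu = lu_factor(sdirk_matrix)

# DIRK: diagonal entries differ between stages (illustrative values), so each
# stage needs its own matrix and factorisation, even with J_f frozen at y0.
dirk_diagonals = [0.2, 0.3, 0.4]
dirk_matrices = [eye - a_nn * dt * J_f for a_nn in dirk_diagonals]

# Reuse the single SDIRK factorisation to solve (I - gamma*dt*J_f) x = rhs.
rhs = jnp.array([1.0, -1.0])
x = lu_solve(sdirk_lu, rhs)
```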

If so, then within the code the function F_n is defined here.

Computing this Jacobian is delegated to the root finder (e.g. Kvaerno5().root_finder). For the default diffrax.VeryChord root finder, the Jacobian is computed here and is therefore available as state.linear_state[0], although that exact location is a private implementation detail. You could grab it by wrapping the root finder and intercepting it in the root finder's init method.
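To illustrate the pattern (Jacobian computed in `init`, stored in the solver state, reused by every iteration), here is a toy chord root finder in plain JAX. `ToyChord` and `ChordState` are hypothetical names; this is not the diffrax or Optimistix API, and its state layout is not the private `linear_state` mentioned above. Returning the state is what makes the Jacobian reachable from outside.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

class ChordState(NamedTuple):
    jac: jax.Array  # dF/du, evaluated once in init at the initial guess
    jac_lu: tuple   # LU factorisation of jac, reused by every iteration

class ToyChord:
    """Hypothetical chord (simplified Newton) root finder; not the diffrax API."""

    def __init__(self, max_steps: int = 50, tol: float = 1e-6):
        self.max_steps = max_steps
        self.tol = tol

    def init(self, F: Callable, u0: jax.Array) -> ChordState:
        jac = jax.jacfwd(F)(u0)
        return ChordState(jac=jac, jac_lu=lu_factor(jac))

    def solve(self, F: Callable, u0: jax.Array):
        state = self.init(F, u0)
        u = u0
        for _ in range(self.max_steps):
            # Chord update: reuse the frozen Jacobian instead of recomputing it.
            du = lu_solve(state.jac_lu, F(u))
            u = u - du
            if jnp.max(jnp.abs(du)) < self.tol:
                break
        return u, state  # returning the state exposes the Jacobian to the caller

# Hypothetical vector field and a single implicit stage with no earlier stages.
def f(y):
    return jnp.array([-y[0] + y[0] * y[1], -2.0 * y[1] + y[0] ** 2])

y0 = jnp.array([1.0, 0.5])
dt = 0.1
gamma = 0.25  # illustrative diagonal tableau entry

def F_n(u):
    return u - f(y0 + gamma * u * dt)

u_n, state = ToyChord().solve(F_n, jnp.zeros(2))
print(state.jac)   # dF_n/du_n evaluated at the initial guess
print(F_n(u_n))    # residual, which should be close to zero
```

Running the solver eagerly (outside `jax.jit`) keeps the Python `if`/`break` valid; a jitted version would need `jax.lax.while_loop` instead.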

Does that answer your question?

lsptrsn commented 7 months ago

That answers my question. Thank you very much!