The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
Description
Solving ODEs together with their sensitivities is computationally very intensive. The sensitivity calculation requires evaluating the Jacobian of the ODE right-hand side with respect to the states and the parameters at each step of the ODE integration. Currently this is done using nested autodiff. Experiments suggest that a speedup of more than 2x is possible when these Jacobians are provided in analytic form.
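For context, the standard forward sensitivity equations (a textbook formulation, not necessarily the exact layout used inside stan-math) show where the two Jacobians enter. With N states y, K parameters theta, and sensitivities S = dy/dtheta (an N x K matrix):

```latex
\begin{aligned}
\frac{dy}{dt} &= f(t, y, \theta), & y(t_0) &= y_0, \\
\frac{dS}{dt} &= J_y \, S + J_\theta, & S(t_0) &= \frac{\partial y_0}{\partial \theta}, \\
J_y &= \frac{\partial f}{\partial y}(t, y, \theta) \in \mathbb{R}^{N \times N}, &
J_\theta &= \frac{\partial f}{\partial \theta}(t, y, \theta) \in \mathbb{R}^{N \times K}.
\end{aligned}
```

The coupled system has N + N·K equations, and both J_y and J_theta are needed every time the integrator evaluates its right-hand side.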
This issue suggests adding functions to stan-math that allow the user to pass in the analytic Jacobians of the ODE right-hand side (RHS) with respect to the states and the parameters.
Example
Right now, ODEs are solved with the call
integrate_ode_algo(ode_rhs, y0, t0, times, theta, x_r, x_i, rel_tol, abs_tol, max_steps)
This feature would add one additional signature per algorithm (algo being rk45, adams, or bdf):
integrate_ode_algo(ode_rhs, Jy, Jtheta, y0, t0, times, theta, x_r, x_i, rel_tol, abs_tol, max_steps)
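A minimal C++ sketch of what the analytic Jacobians would be used for: the right-hand side of the coupled state-plus-sensitivity system that an integrator evaluates at every step. The callback types RhsFn and JacFn and the function coupled_rhs are hypothetical stand-ins for the proposed ode_rhs, Jy, and Jtheta functors (x_r and x_i omitted), not stan-math's internal code; storing z as the N states followed by one block of N sensitivities per parameter is likewise an assumption of this sketch. Jy is the N x N Jacobian of ode_rhs with respect to the states, Jtheta the N x K Jacobian with respect to the parameters.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<std::vector<double>>;  // row-major, like real[,]

// Hypothetical callback types mirroring the proposed ode_rhs, Jy and Jtheta
// functors; x_r and x_i are omitted for brevity.
using RhsFn = std::function<Vec(double t, const Vec& y, const Vec& theta)>;
using JacFn = std::function<Mat(double t, const Vec& y, const Vec& theta)>;

// Right-hand side of the coupled system z = [y, S(:,0), ..., S(:,K-1)],
// where column k of S holds dy/dtheta_k. Returns dz/dt, using
// dy/dt = f(t, y, theta) and dS/dt = Jy * S + Jtheta.
Vec coupled_rhs(double t, const Vec& z, const Vec& theta, const RhsFn& f,
                const JacFn& Jy, const JacFn& Jtheta) {
  const std::size_t K = theta.size();
  const std::size_t N = z.size() / (1 + K);
  assert(z.size() == N * (1 + K));  // states plus N sensitivities per parameter

  const Vec y(z.begin(), z.begin() + static_cast<std::ptrdiff_t>(N));
  const Vec dy = f(t, y, theta);         // N values
  const Mat jy = Jy(t, y, theta);        // N x N, jy[i][j] = df_i/dy_j
  const Mat jth = Jtheta(t, y, theta);   // N x K, jth[i][k] = df_i/dtheta_k

  Vec dz(z.size(), 0.0);
  for (std::size_t i = 0; i < N; ++i) dz[i] = dy[i];

  for (std::size_t k = 0; k < K; ++k) {
    const std::size_t off = N * (1 + k);  // start of column k of S inside z
    for (std::size_t i = 0; i < N; ++i) {
      double acc = jth[i][k];
      for (std::size_t j = 0; j < N; ++j) acc += jy[i][j] * z[off + j];
      dz[off + i] = acc;
    }
  }
  return dz;
}
```

In this sketch the two Jacobian calls are plain function evaluations; with nested autodiff, that is where the per-step AD work would be spent instead.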
Here, Jy and Jtheta are the Jacobians of the ode_rhs functor with respect to y and theta, respectively. The signatures of the functors could be:
real [] : ode_rhs(real t, real[] y, real[] theta, real[] x_r, int[] x_i)
real [,] : Jy(real t, real[] y, real[] theta, real[] x_r, int[] x_i)
real [,] : Jtheta(real t, real[] y, real[] theta, real[] x_r, int[] x_i)
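As an illustration of the three proposed functors, here is a hypothetical two-state example (a damped oscillator with parameters theta[0] and theta[1]) written as plain C++ functions. The model, the function bodies, and the row/column convention (row i for component i of ode_rhs, column j for y[j] or theta[j]) are assumptions chosen for this sketch; x_r and x_i are omitted because the model needs no data.

```cpp
#include <cassert>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<std::vector<double>>;  // row-major, like real[,]

// Hypothetical example model (not from stan-math):
//   dy0/dt = y1
//   dy1/dt = -theta0 * y0 - theta1 * y1
Vec ode_rhs(double /*t*/, const Vec& y, const Vec& theta) {
  assert(y.size() == 2 && theta.size() == 2);
  return {y[1], -theta[0] * y[0] - theta[1] * y[1]};
}

// Jy[i][j] = d ode_rhs[i] / d y[j]   (N x N = 2 x 2)
Mat Jy(double /*t*/, const Vec& /*y*/, const Vec& theta) {
  return {{0.0, 1.0},
          {-theta[0], -theta[1]}};
}

// Jtheta[i][k] = d ode_rhs[i] / d theta[k]   (N x K = 2 x 2)
Mat Jtheta(double /*t*/, const Vec& y, const Vec& /*theta*/) {
  return {{0.0, 0.0},
          {-y[0], -y[1]}};
}
```

Because the model is linear in y, Jy depends only on theta, while Jtheta depends only on y; in general both may depend on t, y, and theta.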
For efficiency, we may also consider passing the ode_rhs result into the Jacobian functions. There can easily be circumstances where this is beneficial, since the function value is computed anyway and will be available. So this is an alternative:
real [] : ode_rhs(real t, real[] y, real[] theta, real[] x_r, int[] x_i)
real [,] : Jy(real t, real[] y, real[] ode_rhs, real[] theta, real[] x_r, int[] x_i)
real [,] : Jtheta(real t, real[] y, real[] ode_rhs, real[] theta, real[] x_r, int[] x_i)
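A hypothetical scalar model shows when passing the ode_rhs result pays off: for dy/dt = exp(-theta * y), both partial derivatives are multiples of the right-hand side itself (df/dy = -theta f, df/dtheta = -y f), so the alternative signatures let the Jacobians skip a second call to exp. The model and function bodies below are illustrative assumptions in C++, following the argument order of the alternative signatures with x_r and x_i omitted.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<std::vector<double>>;

// Hypothetical scalar model: dy/dt = exp(-theta * y).
Vec ode_rhs(double /*t*/, const Vec& y, const Vec& theta) {
  assert(y.size() == 1 && theta.size() == 1);
  return {std::exp(-theta[0] * y[0])};
}

// df/dy = -theta * f, built from the passed-in rhs value (no second exp()).
Mat Jy(double /*t*/, const Vec& /*y*/, const Vec& rhs, const Vec& theta) {
  return {{-theta[0] * rhs[0]}};
}

// df/dtheta = -y * f, again reusing rhs instead of recomputing it.
Mat Jtheta(double /*t*/, const Vec& y, const Vec& rhs, const Vec& /*theta*/) {
  return {{-y[0] * rhs[0]}};
}
```

An integrator using these signatures would evaluate ode_rhs once per right-hand-side call and hand the same vector to both Jacobian functions.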
Another option is to let the Jacobian functions return Eigen matrices instead. While it is more natural to return matrices as matrix types, this would create some inconsistency with the real[] arrays used everywhere else in these signatures.
How the correctness of the supplied Jacobians should be assessed, if at all, is unclear at this stage. From the perspective of the stan-math library, it can make sense to assume that the supplied Jacobians are correct and to make no attempt to verify them.
Expected Output
The same outputs, just much faster, because nested AD is avoided during ODE integration.
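The outputs only stay the same if the supplied Jacobians are correct, and the library would not check them, so a user may want to test their Jacobians before relying on them. A minimal C++ sketch of such a check compares a supplied Jacobian with central finite differences of the right-hand side; the helper max_jacobian_error and its signature are hypothetical and not part of stan-math.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<std::vector<double>>;
using RhsFn = std::function<Vec(double t, const Vec& y, const Vec& theta)>;

// Largest absolute difference between a supplied Jacobian of f with respect
// to y (wrt_theta == false) or theta (wrt_theta == true) and a central
// finite-difference approximation with step h.
double max_jacobian_error(const RhsFn& f, const Mat& supplied, double t,
                          const Vec& y, const Vec& theta, bool wrt_theta,
                          double h = 1e-6) {
  const Vec& x = wrt_theta ? theta : y;  // the argument being perturbed
  double max_err = 0.0;
  for (std::size_t j = 0; j < x.size(); ++j) {
    Vec xp = x, xm = x;
    xp[j] += h;
    xm[j] -= h;
    const Vec fp = wrt_theta ? f(t, y, xp) : f(t, xp, theta);
    const Vec fm = wrt_theta ? f(t, y, xm) : f(t, xm, theta);
    assert(fp.size() == supplied.size());
    for (std::size_t i = 0; i < fp.size(); ++i) {
      const double fd = (fp[i] - fm[i]) / (2.0 * h);  // approx df_i/dx_j
      max_err = std::max(max_err, std::fabs(fd - supplied[i][j]));
    }
  }
  return max_err;
}
```

A result close to zero, relative to the size of the Jacobian entries, suggests the supplied Jacobian is consistent with ode_rhs; the step h trades truncation error against rounding error, so tolerances should be chosen loosely.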
Additional Information
Current Math Version
v2.18.0