stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

Algebraic Solver Differentiation Speedup #2401

Closed jgaeb closed 3 years ago

jgaeb commented 3 years ago

Description

(This issue is a revival of #1257. Essentially all of what follows is cribbed from @charlesm93!)

It may be possible to improve how derivatives are propagated through the algebraic solver. The basic idea is to use the implicit function theorem to get an “almost adjoint” method. Currently, the vari class for the algebraic solver is implemented so that the full Jacobian of the solution with respect to the parameters is calculated explicitly.
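To make the current behavior concrete, here is a sketch of what "calculating the Jacobian explicitly" means under the implicit function theorem. The notation is an assumption chosen for this illustration and may differ from the linked post: $x^*(\theta)$ is the solution, $\theta$ the parameters, $J_x = \partial f / \partial x$ and $J_\theta = \partial f / \partial \theta$ are evaluated at the solution, and a bar denotes a reverse-mode adjoint.

```latex
f(x^*(\theta), \theta) = 0
\;\Longrightarrow\;
J_x \, \frac{\partial x^*}{\partial \theta} + J_\theta = 0
\;\Longrightarrow\;
\frac{\partial x^*}{\partial \theta} = -J_x^{-1} J_\theta ,
\qquad
\bar{\theta} = \left( \frac{\partial x^*}{\partial \theta} \right)^{\top} \bar{x} .
```

Forming $\partial x^* / \partial \theta$ this way needs one linear solve per parameter (a matrix-matrix solve), even though reverse mode only ever uses the product with $\bar{x}$.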

With nested gradients (#1856), however, it is possible to avoid constructing the entire Jacobian. Using the method outlined here, it should be possible (in the notation of the linked post) to replace this computation with a single reverse-mode sweep, n forward-mode sweeps (to calculate the partial derivatives of f with respect to v), and a matrix-vector solve. This also eliminates a matrix-matrix solve, and additional speedup may come from getting the partial derivatives of f with respect to v from the solver itself rather than calculating them explicitly.
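The difference between the two approaches can be sketched on a toy problem. The following is a minimal, self-contained illustration and not Stan Math code: the 2x2 system, the hand-written Jacobians, and every name (`residual`, `jac_x`, `jac_theta`, `adjoint_explicit`, `adjoint_implicit`) are hypothetical. The notation is also an assumption: `x` holds the unknowns, `th` the parameters, and `x_bar` the incoming adjoint of the solution. In the library, the final product with the parameter Jacobian would presumably come from a reverse-mode sweep rather than a hand-coded matrix.

```cpp
#include <array>
#include <cassert>
#include <cmath>

using Vec2 = std::array<double, 2>;
struct Mat2 { double a, b, c, d; };  // row-major [[a, b], [c, d]]

// Hypothetical toy system f(x, th) = 0 (not taken from the issue).
Vec2 residual(const Vec2& x, const Vec2& th) {
  return {x[0] * x[0] + x[1] - th[0] * th[1], x[0] + x[1] * x[1] - th[1]};
}
// Hand-coded Jacobians of the residual w.r.t. x and w.r.t. th.
Mat2 jac_x(const Vec2& x) { return {2 * x[0], 1.0, 1.0, 2 * x[1]}; }
Mat2 jac_theta(const Vec2& th) { return {-th[1], -th[0], 0.0, -1.0}; }

Mat2 transpose(const Mat2& m) { return {m.a, m.c, m.b, m.d}; }
Vec2 mul(const Mat2& m, const Vec2& v) {
  return {m.a * v[0] + m.b * v[1], m.c * v[0] + m.d * v[1]};
}
// Solve m * y = r by Cramer's rule (adequate for a 2x2 sketch).
Vec2 solve(const Mat2& m, const Vec2& r) {
  const double det = m.a * m.d - m.b * m.c;
  assert(std::fabs(det) > 1e-12);
  return {(r[0] * m.d - m.b * r[1]) / det, (m.a * r[1] - m.c * r[0]) / det};
}

// Plain Newton iteration standing in for the algebraic solver.
Vec2 newton(Vec2 x, const Vec2& th) {
  for (int it = 0; it < 50; ++it) {
    const Vec2 r = residual(x, th);
    if (std::fabs(r[0]) + std::fabs(r[1]) < 1e-14) break;
    const Vec2 dx = solve(jac_x(x), r);
    x = {x[0] - dx[0], x[1] - dx[1]};
  }
  return x;
}

// Explicit route: build dx/dth = -J_x^{-1} J_th column by column
// (one solve per parameter), then contract with x_bar.
Vec2 adjoint_explicit(const Vec2& x, const Vec2& th, const Vec2& x_bar) {
  const Mat2 Jx = jac_x(x), Jt = jac_theta(th);
  const Vec2 c0 = solve(Jx, {Jt.a, Jt.c});  // J_x^{-1} * column 0 of J_th
  const Vec2 c1 = solve(Jx, {Jt.b, Jt.d});  // J_x^{-1} * column 1 of J_th
  return {-(c0[0] * x_bar[0] + c0[1] * x_bar[1]),
          -(c1[0] * x_bar[0] + c1[1] * x_bar[1])};
}

// Implicit route: one solve J_x^T eta = x_bar, then th_bar = -J_th^T eta.
Vec2 adjoint_implicit(const Vec2& x, const Vec2& th, const Vec2& x_bar) {
  const Vec2 eta = solve(transpose(jac_x(x)), x_bar);
  const Vec2 t = mul(transpose(jac_theta(th)), eta);
  return {-t[0], -t[1]};
}
```

Calling `newton({1.2, 0.8}, {1.0, 2.0})` and passing the result to both adjoint functions with the same `x_bar` should give matching parameter adjoints. The implicit route needs a single linear solve regardless of the number of parameters, whereas the explicit route needs one per parameter.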

The plan is to rewrite stan::math::algebra_solver_vari to defer the calculation of the adjoints until .solve() is actually called. (This may require some tweaks to the Newton and Powell solver wrappers that use it as well.)

Expected Output

Propagating derivatives through implicit algebraic equations in reverse mode should be faster.

Discussion on Forum

https://discourse.mc-stan.org/t/algebraic-solver-differentiation-speedup/20845

Current Version:

v4.0.1

jgaeb commented 3 years ago

Closed by #2421!