stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

Add functions to write hessian vector products into an array #2949

Open roualdes opened 11 months ago

roualdes commented 11 months ago

The functions

  1. stan::math::internal::finite_diff_hessian_times_vector_auto()
  2. stan::math::hessian_times_vector()

accept the Hessian-vector product (hvp or Hv, respectively) as an Eigen::VectorXd& out parameter. Similar to issue #2739, it would benefit BridgeStan to have signatures that effectively write into a double* hvp, so that BridgeStan could offer cheaper Hessian-vector products.

I propose adding functions with signatures

  1. stan::math::internal::finite_diff_hessian_times_vector_auto()
template <typename F, typename EigVec, typename InputIt, 
  require_eigen_vector_vt<std::is_arithmetic, EigVec>* = nullptr>
void finite_diff_hessian_times_vector_auto(const F& f,
                                           const EigVec& x,
                                           const EigVec& v,
                                           double& fx,
                                           InputIt first_hvp, InputIt last_hvp)
  2. stan::math::hessian_times_vector()
template <typename F, typename EigVec, typename InputIt, 
  require_eigen_vector_vt<std::is_arithmetic, EigVec>* = nullptr>
void hessian_times_vector(const F& f,
                          const EigVec& x,
                          const EigVec& v,
                          double& fx,
                          InputIt first_Hv, InputIt last_Hv) 
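
To make the shape of the proposed overloads concrete, here is a minimal standalone sketch of a function that writes a Hessian-vector product into an iterator range. It is not Stan Math code: `fd_gradient` and `hvp_into_range` are hypothetical names, `std::vector<double>` stands in for the Eigen vector types, and the product is approximated with plain central differences rather than Stan's autodiff or its finite-difference routines. The iterator is used for writing, so it has to be a writable forward iterator such as `double*`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper (not part of Stan Math): central-difference
// gradient of f at x. x is taken by value so it can be perturbed.
template <typename F>
std::vector<double> fd_gradient(const F& f, std::vector<double> x,
                                double h = 1e-5) {
  std::vector<double> g(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double xi = x[i];
    x[i] = xi + h;
    const double f_plus = f(x);
    x[i] = xi - h;
    const double f_minus = f(x);
    x[i] = xi;  // restore the coordinate before moving on
    g[i] = (f_plus - f_minus) / (2.0 * h);
  }
  return g;
}

// Hypothetical sketch of the proposed calling convention: the result
// Hv ~ (grad f(x + eps v) - grad f(x - eps v)) / (2 eps) is written
// into the caller-owned range [first, last) instead of an Eigen vector.
template <typename F, typename It>
void hvp_into_range(const F& f, const std::vector<double>& x,
                    const std::vector<double>& v, double& fx,
                    It first, It last, double eps = 1e-4) {
  assert(x.size() == v.size());
  assert(static_cast<std::size_t>(std::distance(first, last)) == x.size());
  fx = f(x);
  std::vector<double> x_plus(x), x_minus(x);
  for (std::size_t i = 0; i < x.size(); ++i) {
    x_plus[i] += eps * v[i];
    x_minus[i] -= eps * v[i];
  }
  const std::vector<double> g_plus = fd_gradient(f, x_plus);
  const std::vector<double> g_minus = fd_gradient(f, x_minus);
  std::size_t i = 0;
  for (It it = first; it != last; ++it, ++i) {
    *it = (g_plus[i] - g_minus[i]) / (2.0 * eps);
  }
}
```

A caller holding a raw buffer `double* hvp` of length `n` would pass `hvp` and `hvp + n` as the range, so no intermediate vector has to be copied out afterwards.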

In fact, stan::math::hessian could benefit from a similar added function, but I'm having a hard time wrapping my head around the design of the signature since the out parameter here is a matrix. @WardBrian, would you mind lending an opinion?

Current Version:

v4.7.0

WardBrian commented 11 months ago

The Hessian could be represented as a strided matrix (essentially a row-wise or column-wise flattening). Since it should always be symmetric, it doesn't even matter which direction you flatten in.
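
The flattening idea can be illustrated with a small standalone sketch. The helper names `flatten_row_major` and `flatten_col_major` are hypothetical, and a `std::vector` of rows stands in for an Eigen matrix. For a symmetric matrix the two layouts fill the buffer identically, because entry (i, j) equals entry (j, i).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // stand-in for an Eigen matrix

// Hypothetical helper: write an n x n matrix into out (length n * n),
// row by row, so entry (i, j) lands at out[i * n + j].
inline void flatten_row_major(const Matrix& H, double* out) {
  const std::size_t n = H.size();
  for (std::size_t i = 0; i < n; ++i) {
    assert(H[i].size() == n);  // the matrix must be square
    for (std::size_t j = 0; j < n; ++j) {
      out[i * n + j] = H[i][j];
    }
  }
}

// Hypothetical helper: same matrix, column by column, so entry (i, j)
// lands at out[j * n + i].
inline void flatten_col_major(const Matrix& H, double* out) {
  const std::size_t n = H.size();
  for (std::size_t i = 0; i < n; ++i) {
    assert(H[i].size() == n);
    for (std::size_t j = 0; j < n; ++j) {
      out[j * n + i] = H[i][j];
    }
  }
}
```

Under this convention a Hessian overload would only need a pointer (or iterator range) to `n * n` doubles, and the caller would not need to know which storage order the library used.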