stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

Ambiguity between `fwd` and `mix` signatures for `hessian()` #3056

Open andrjohns opened 2 months ago

andrjohns commented 2 months ago

Description

With the current signatures for `hessian` in `fwd` and `mix`, it is not possible to call the `fwd` implementation with `double` types.

`mix/functor/hessian.hpp`:

```cpp
template <typename F>
void hessian(const F& f, const Eigen::Matrix<double, Eigen::Dynamic, 1>& x,
             double& fx, Eigen::Matrix<double, Eigen::Dynamic, 1>& grad,
             Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>& H);
```

`fwd/functor/hessian.hpp`:

```cpp
template <typename T, typename F>
void hessian(const F& f, const Eigen::Matrix<T, Eigen::Dynamic, 1>& x, T& fx,
             Eigen::Matrix<T, Eigen::Dynamic, 1>& grad,
             Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& H);
```

Because the scalar type is fixed to `double` in the `mix` signature but templated in the `fwd` signature, both overloads match a call with `double` arguments, and overload resolution always picks the more specialized `mix` implementation. This makes testing and validation of the `fwd` implementation (or even just using it as an alternative) a bit of a hurdle.
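To make the resolution rule concrete, here is a minimal sketch that mirrors the two overloads with simplified, hypothetical stand-ins: `std::vector` replaces the Eigen types, the `grad` and `H` outputs are dropped, and each overload returns a tag naming the implementation it stands for. It is not Stan code; it only shows that a `double` call lands on the fixed-`double` overload while another scalar type falls through to the templated one.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the two overloads quoted above (not Stan code).

// "mix"-style overload: the scalar type is fixed to double.
template <typename F>
std::string hessian(const F& f, const std::vector<double>& x, double& fx) {
  fx = f(x);
  return "mix";
}

// "fwd"-style overload: the scalar type T is a template parameter.
template <typename T, typename F>
std::string hessian(const F& f, const std::vector<T>& x, T& fx) {
  fx = f(x);
  return "fwd";
}

// Example objective: sum of squares, generic over the scalar type.
struct SumSquares {
  template <typename T>
  T operator()(const std::vector<T>& v) const {
    T s = 0;
    for (const T& e : v) s += e * e;
    return s;
  }
};

// Calls hessian() once with double and once with float and reports which
// overload was selected in each case.
std::pair<std::string, std::string> resolve_demo() {
  std::vector<double> xd{1.0, 2.0};
  double fxd = 0;
  // Both overloads are viable (the fwd one with T = double), but partial
  // ordering of function templates prefers the more specialized double one.
  std::string with_double = hessian(SumSquares{}, xd, fxd);

  std::vector<float> xf{1.0f, 2.0f};
  float fxf = 0;
  // The double-only overload is not viable for float, so fwd is chosen.
  std::string with_float = hessian(SumSquares{}, xf, fxf);
  return {with_double, with_float};
}
```

With these stand-ins, `resolve_demo()` yields `{"mix", "fwd"}`: there is no ordinary call with `double` arguments that reaches the templated overload.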

I think any fix for this would imply a breaking change (changing function names or arguments), so it could be bundled into the 5.0 release (@SteveBronder)?

Current Version:

v4.8.1

SteveBronder commented 2 months ago

I think this makes sense to fix. Which signatures do you want to add or change? We would put this in the 5.0 breaking changes.

andrjohns commented 1 month ago

I think we want to keep the mix implementation being called by default, since it's more efficient, but add a way to force the fwd impl to be called.

I can think of two ways:

  1. Rename the `fwd` version to `hessian_fwd`, `hessian_fvar`, or similar
  2. Add a boolean template parameter, defaulting to `true`, such that `hessian<false>(...)` resolves to the `fwd` implementation
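A minimal sketch of option 2, assuming a single public entry point that dispatches at compile time with C++17 `if constexpr`. The parameter name `UseMix`, the helpers `mix_impl` and `fwd_impl`, the `sketch` namespace, and the simplified `std::vector` signature (no `grad`/`H` outputs) are all hypothetical stand-ins, not existing Stan Math APIs.

```cpp
#include <string>
#include <type_traits>
#include <vector>

namespace sketch {

// Hypothetical stand-in for the mix-mode implementation (double only).
template <typename F>
std::string mix_impl(const F& f, const std::vector<double>& x, double& fx) {
  fx = f(x);
  return "mix";
}

// Hypothetical stand-in for the fwd-mode implementation (any scalar T).
template <typename T, typename F>
std::string fwd_impl(const F& f, const std::vector<T>& x, T& fx) {
  fx = f(x);
  return "fwd";
}

// Hypothetical entry point: UseMix defaults to true, so calls without an
// explicit template argument keep using mix for double inputs, while
// hessian<false>(...) forces the fwd path. T and F are still deduced.
template <bool UseMix = true, typename T, typename F>
std::string hessian(const F& f, const std::vector<T>& x, T& fx) {
  if constexpr (UseMix && std::is_same_v<T, double>) {
    return mix_impl(f, x, fx);
  } else {
    // Non-double scalars, or an explicit request for fwd.
    return fwd_impl(f, x, fx);
  }
}

}  // namespace sketch

// Example objective: sum of squares, generic over the scalar type.
struct SumSquares {
  template <typename T>
  T operator()(const std::vector<T>& v) const {
    T s = 0;
    for (const T& e : v) s += e * e;
    return s;
  }
};

struct Choices {
  std::string default_double;
  std::string forced_fwd_double;
  std::string default_float;
};

// Records which implementation each call form reaches.
Choices option2_demo() {
  std::vector<double> xd{1.0, 2.0};
  double fxd = 0;
  std::vector<float> xf{1.0f, 2.0f};
  float fxf = 0;
  Choices c;
  c.default_double = sketch::hessian(SumSquares{}, xd, fxd);
  c.forced_fwd_double = sketch::hessian<false>(SumSquares{}, xd, fxd);
  c.default_float = sketch::hessian(SumSquares{}, xf, fxf);
  return c;
}
```

In this sketch, a plain `hessian(f, x, fx)` with `double` inputs still reaches the mix path, `hessian<false>(f, x, fx)` reaches fwd, and non-`double` scalars always use fwd. Option 1 would avoid the extra template argument at the cost of a second public name (for example a hypothetical `hessian_fwd`).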

Thoughts/preferences?