stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License
737 stars, 185 forks

(inverse) FFT analytic adjoint-Jacobian #2740

Closed by bob-carpenter 2 years ago

bob-carpenter commented 2 years ago

Description

Add adjoint-Jacobian specialization for reverse mode for the fast Fourier transform (FFT) and its inverse.

Example

FFT case

If y = fft(x), then the adjoint-Jacobian is just the inverse FFT applied to the adjoint of the result,

adjoint(x) += ifft(adjoint(y))

Inverse FFT case

If y = ifft(x), then the adjoint-Jacobian update rule is inverted,

adjoint(x) += fft(adjoint(y))
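
The two rules above hold up to a constant that depends on how the transforms are normalized. As a sketch, assume (the issue does not state a convention) that fft is the unnormalized DFT with matrix $F_{kn} = e^{-2\pi i k n / N}$, that ifft is its exact inverse $F^{-1} = F^{\mathsf{H}} / N$, and that the adjoint of a complex variable $z = a + ib$ packs the two real partials as $\bar{z} = \partial L / \partial a + i \, \partial L / \partial b$. Under those assumptions the chain rule gives:

```latex
\begin{align*}
y = F x
  &\;\Longrightarrow\;
  \bar{x} \mathrel{+}= F^{\mathsf{H}} \bar{y}
  = N \,\operatorname{ifft}(\bar{y}), \\
y = F^{-1} x
  &\;\Longrightarrow\;
  \bar{x} \mathrel{+}= \left(F^{-1}\right)^{\mathsf{H}} \bar{y}
  = \tfrac{1}{N} \,\operatorname{fft}(\bar{y}).
\end{align*}
```

With a unitary convention, where both transforms are scaled by $1/\sqrt{N}$, the factors of $N$ disappear and the rules read exactly as written above.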

Expected Output

The existing tests still pass, but the reverse pass is faster and uses less memory.

Current Version:

v4.3.2

WardBrian commented 2 years ago

I think this is at a level where I could implement it, modulo two questions.

  1. First, conceptually, is it okay that the variables x and y in the formula adjoint(x) += fft(adjoint(y)) are complex-valued? I'm not sure exactly what our autodiff does here.
  2. I think at the moment it's more or less impossible to write a reverse-mode specialization for complex-valued matrices because we need to implement another overload for our adj() and val() operators on Eigen matrices. Right now if you have an Eigen::Matrix<std::complex<var>> and call .adj() on it, you get:
    ‘class std::complex<stan::math::var_value<double> >’ has no member named ‘vi_’
      160 |       operator()(T &v) const { return v.vi_->adj_; }

    @SteveBronder - am I right about this, or is this a misunderstanding of the system on my part?

Here is essentially what I think should work, modulo that operator issue and me possibly getting a type or two wrong:

template <typename V, require_eigen_vector_vt<is_complex, V>* = nullptr,
          require_var_t<base_type_t<value_type_t<V>>>* = nullptr>
inline plain_type_t<V> fft(const V& v) {
  if (unlikely(v.size() < 1)) {
    return plain_type_t<V>(v);
  }

  Eigen::FFT<base_type_t<V>> fft;

  arena_t<V> arena_v = v;
  arena_t<V> res = fft.fwd(arena_v.val().eval());

  reverse_pass_callback([arena_v, res, fft]() mutable {
    arena_v.adj().array() += fft.inv(res.adj().eval());
  });

  return plain_type_t<V>(res);
}

bob-carpenter commented 2 years ago

You can drop the .array() before the += if .adj() gives you a vector back.

I would just call our fft function or reconstruct fft locally rather than passing a copy of fft through the closure.

Also, we don't want to call this base FFT thing; we just want to call our own fft() function. Then we don't have to maintain a copy of fft in the closure.
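
The suggestion to keep no FFT object in the closure can be illustrated outside of Stan with a toy tape. The sketch below is not Stan Math code: `ToyNode`, `ToyTape`, `make_toy_node`, `toy_fft`, `toy_inv_fft`, and `toy_fft_rev` are hypothetical names, the transform is a naive O(N^2) DFT, and the factor of N in the callback assumes an unnormalized forward transform paired with a 1/N inverse. The callback captures only the input and output nodes and calls the free function `toy_inv_fft` during the reverse pass.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

using cvec = std::vector<std::complex<double>>;

// Naive O(N^2) DFT; sign = -1 is the forward transform,
// sign = +1 is the inverse transform without the 1/N scaling.
inline cvec toy_dft(const cvec& x, int sign) {
  const std::size_t n = x.size();
  const double two_pi = 2.0 * std::acos(-1.0);
  cvec y(n);
  for (std::size_t k = 0; k < n; ++k) {
    std::complex<double> acc(0.0, 0.0);
    for (std::size_t j = 0; j < n; ++j) {
      const double angle = sign * two_pi * static_cast<double>(k * j)
                           / static_cast<double>(n);
      acc += x[j] * std::polar(1.0, angle);
    }
    y[k] = acc;
  }
  return y;
}

inline cvec toy_fft(const cvec& x) { return toy_dft(x, -1); }

inline cvec toy_inv_fft(const cvec& x) {
  cvec y = toy_dft(x, +1);
  for (auto& v : y) {
    v /= static_cast<double>(x.size());
  }
  return y;
}

// Hypothetical stand-in for an autodiff variable holding a complex vector.
struct ToyNode {
  cvec val;
  cvec adj;  // adj[i] = dL/dRe(val[i]) + i * dL/dIm(val[i])
};

inline std::shared_ptr<ToyNode> make_toy_node(const cvec& v) {
  auto node = std::make_shared<ToyNode>();
  node->val = v;
  node->adj.assign(v.size(), std::complex<double>(0.0, 0.0));
  return node;
}

// Hypothetical stand-in for the reverse-mode tape.
struct ToyTape {
  std::vector<std::function<void()>> callbacks;
  // Run the callbacks in reverse order of registration.
  void reverse() {
    for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
      (*it)();
    }
  }
};

inline std::shared_ptr<ToyNode> toy_fft_rev(ToyTape& tape,
                                            const std::shared_ptr<ToyNode>& x) {
  assert(x->adj.size() == x->val.size());
  auto y = make_toy_node(toy_fft(x->val));
  // Capture only the two nodes; no FFT object lives in the closure.
  tape.callbacks.push_back([x, y]() {
    const cvec g = toy_inv_fft(y->adj);
    const double n = static_cast<double>(g.size());
    for (std::size_t i = 0; i < g.size(); ++i) {
      x->adj[i] += n * g[i];
    }
  });
  return y;
}
```

A caller would build the input with `make_toy_node`, call `toy_fft_rev`, seed `y->adj` with the downstream adjoint, run `tape.reverse()`, and read the accumulated result from `x->adj`.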