stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

Complex support for more linear algebra functions #2699

Open WardBrian opened 2 years ago

WardBrian commented 2 years ago

This is based on https://github.com/stan-dev/stanc3/issues/1162

The following functions do not fully compile in expression tests when added to stanc. I've attached gists which show the compiler errors.

spinkney commented 2 years ago

@WardBrian and @bob-carpenter, for generalized_inverse, the following MathOverflow post gives the derivatives: https://mathoverflow.net/questions/25778/analytical-formula-for-numerical-derivative-of-the-matrix-pseudo-inverse. Quoting from it:

For complex matrices, the above formula works if Hermitian conjugates are used instead of transposes. I don't have any reference on this (anyone?), but this is verified by all the numerical tests I did (with matrices of various shapes and ranks).
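For reference, here is the standard differential of the Moore-Penrose pseudo-inverse with Hermitian conjugates written in place of transposes, as the quoted remark suggests. It is a sketch of the well-known constant-rank formula, not something verified against the Stan Math code; $A^+$ is the pseudo-inverse, $A^H$ the conjugate transpose, and $A^{+H}$ is shorthand for $(A^+)^H$. It only holds where the rank of $A$ is locally constant.

```latex
\mathrm{d}A^{+}
  = -A^{+}\,(\mathrm{d}A)\,A^{+}
  + A^{+}A^{+H}\,(\mathrm{d}A^{H})\,(I - AA^{+})
  + (I - A^{+}A)\,(\mathrm{d}A^{H})\,A^{+H}A^{+}
```

For real $A$, the conjugate transpose reduces to the ordinary transpose and this is the usual real-valued formula.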

bob-carpenter commented 2 years ago

Thanks, @spinkney. Right now, we're just autodiffing through the Eigen implementations of complex functions, but this is what we'll need to add analytic derivatives.

SteveBronder commented 2 years ago

@bob-carpenter can you take a look at the branch below for determinant? One of the test matrices is failing and idk enough about complex numbers to know why

https://github.com/stan-dev/math/compare/feature/more-complex-matrix-funs?expand=1

git pull
git checkout feature/more-complex-matrix-funs
./runTests.py -j3 ./test/unit/math/mix/fun/determinant_test

Specifically, matrix e in the AD test framework fails when checking the Hessian using fvar types:

  complex_matrix e(2, 2);
  e << complex_d(0), complex_d(1), complex_d(2), complex_d(3);

All the other ones pass though. You can see a printout of the gradient and hessian calcs from failures by adding -DSTAN_TEST_PRINT_MATRIX_FAILURE to CXXFLAGS

bob-carpenter commented 2 years ago

I think there's a general problem with forward-mode autodiff of complex numbers interacting with Eigen. The forward-mode tests are failing in my FFT PR as well, even though all of the reverse-mode tests pass. What's odd is that forward mode works just fine with all the scalar complex operations, so I'm puzzled that it fails with matrix operations.

You can use R to check what the answers should be.

> a <- matrix(c(0, 1, 2, 3), 2, 2, byrow=TRUE)

> a
     [,1] [,2]
[1,]    0    1
[2,]    2    3

> det(a)
[1] -2

In general, for things with only real components, the answer should be the same as the real operation. I have no idea what the derivatives look like here---it could be failing there. But it shouldn't if you can do the real valued equivalent of this operation.
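As a small illustration of that check, here is a sketch in plain C++ (no Eigen or Stan types) that computes a 2x2 determinant with one template and compares the real and complex results for the matrix above. The helper name det2 is made up for this example.

```cpp
#include <complex>

// Determinant of the 2x2 matrix [[a, b], [c, d]].
// Works for any scalar type with * and -, e.g. double or std::complex<double>.
// det2 is a hypothetical helper for this sketch, not a Stan Math function.
template <typename T>
T det2(const T& a, const T& b, const T& c, const T& d) {
  return a * d - b * c;
}
```

Calling det2(0.0, 1.0, 2.0, 3.0) and det2 on the same entries wrapped in std::complex<double> should both give -2, with a zero imaginary part in the complex case.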

SteveBronder commented 2 years ago

Bob, I'm looking over the AD framework and I'm a little confused about how it handles serialization and deserialization for complex types. Like for test_grad_hessian in the docs the value to test is x, but x is always an Eigen::VectorXd. What is the path to parse complex inputs happening at? The reason I'm checking is that I wanted to make sure we are making a std::complex<fvar<double>> and not an fvar<std::complex<double>>, etc.

WardBrian commented 2 years ago

Found some more signatures we would probably like to support but that don't compile:

bob-carpenter commented 2 years ago

Found some more signatures we would probably like to support but that don't compile:

Does the other order, subtract(complex_matrix, complex), work?

bob-carpenter commented 2 years ago

Like for test_grad_hessian in the docs the value to test is x, but x is always an Eigen::VectorXd. What is the path to parse complex inputs happening at?

The top-level call delegates to expect_ad_v:

template <typename F, typename T>
void expect_ad(const ad_tolerances& tols, const F& f, const T& x) {
  internal::expect_ad_v(tols, f, x);
}

then expect_ad_v knocks everything down into real-valued scalars with serialization so that the function g just operates on real-valued vectors v:

template <typename F, typename T>
void expect_ad_v(const ad_tolerances& tols, const F& f, const T& x) {
  auto g = [&](const auto& v) {
    auto ds = to_deserializer(v);
    auto xds = ds.read(x);
    return serialize_return(eval(f(xds)));
  };
  internal::expect_ad_helper(tols, f, g, serialize_args(x), x);
}
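To make the "knock down to reals" idea concrete, here is a minimal sketch using only the standard library. The names flatten_complex and unflatten_complex are hypothetical, and the real-then-imaginary ordering is an assumption of this sketch, not a statement about how the test framework lays out its values.

```cpp
#include <complex>
#include <cstddef>
#include <vector>

// Flatten complex values into reals as [re0, im0, re1, im1, ...].
// The interleaved layout is an assumption made for this sketch.
std::vector<double> flatten_complex(
    const std::vector<std::complex<double>>& zs) {
  std::vector<double> out;
  out.reserve(2 * zs.size());
  for (const auto& z : zs) {
    out.push_back(z.real());
    out.push_back(z.imag());
  }
  return out;
}

// Rebuild complex values from a flat vector of scalars of type T.
// Because the result is std::complex<T>, reading from a vector of some
// autodiff scalar T would yield a complex number *of* that scalar,
// not an autodiff wrapper around a complex number.
template <typename T>
std::vector<std::complex<T>> unflatten_complex(const std::vector<T>& v) {
  std::vector<std::complex<T>> out;
  out.reserve(v.size() / 2);
  for (std::size_t i = 0; i + 1 < v.size(); i += 2) {
    out.emplace_back(v[i], v[i + 1]);
  }
  return out;
}
```

In this sketch the scalar type of the flat vector decides the scalar type inside std::complex, which is the property the question about std::complex<fvar<double>> versus fvar<std::complex<double>> is asking about.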

The internal helper evaluates the function both with the original arguments xs and with the serialized form of the arguments x (always a VectorXd, because complex values get broken down into real and imaginary components). It tests that the result is the same evaluated both ways and checks the throw behavior. If nothing throws, it checks the derivatives using a function h, which projects the original function onto a one-dimensional output.

template <typename F, typename G, typename... Ts>
void expect_ad_helper(const ad_tolerances& tols, const F& f, const G& g,
                      const Eigen::VectorXd& x, Ts... xs) {
  auto h
      = [&](const int i) { return [&g, i](const auto& v) { return g(v)[i]; }; };
  size_t result_size = 0;
  try {
    auto y1 = eval(f(xs...));  // original types, including int
    auto y2 = eval(g(x));      // all int cast to double
    auto y1_serial = serialize<double>(y1);
    expect_near_rel("expect_ad_helper", y1_serial, y2, 1e-10);
    result_size = y1_serial.size();
  } catch (...) {
    internal::expect_all_throw(h(0), x);
    return;
  }
  for (size_t i = 0; i < result_size; ++i) {
    expect_ad_derivatives(tols, h(i), x);
  }
}

So I don't think this is the problem. Also, the Hessians and everything work fine for the scalar complex functions---it's only the FFT I'm having trouble with.

Edit: I forgot the final function, which just calls all the specific tests with the serialized function:

template <typename G>
void expect_ad_derivatives(const ad_tolerances& tols, const G& g,
                           const Eigen::VectorXd& x) {
  double gx = g(x);
  test_gradient(tols, g, x, gx);
#ifndef STAN_MATH_TESTS_REV_ONLY
  test_gradient_fvar(tols, g, x, gx);
  test_hessian(tols, g, x, gx);
  test_hessian_fvar(tols, g, x, gx);
  test_grad_hessian(tols, g, x, gx);
#endif
}

WardBrian commented 2 years ago

Does the other order, subtract(complex_matrix, complex), work?

No, the error is essentially the same

bob-carpenter commented 2 years ago

I can't figure out what's going wrong with the FFT tests. They work for reverse mode, but have a bunch of failures in forward mode, including their use for Hessians, etc.

I think it's OK to release with this working only for reverse mode, but I don't know how to go about testing for that. There's a preprocessor flag, STAN_MATH_TESTS_REV_ONLY, but I want this to apply to just the FFT and inverse FFT functions.

I'm going to add a feature where the tests are only run if the tolerance is less than +infinity. That'll let me proceed without changing any signatures and only having to modify the expect_ad_derivatives function. It's a bit of a hack, because technically a high tolerance should still check exception behavior, but I'm not going to do that.
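A minimal sketch of that gating idea follows, assuming a per-test tolerance stored as a double. The struct toy_tolerances and the helper should_run are hypothetical stand-ins for illustration; they are not the library's ad_tolerances or any existing function.

```cpp
#include <limits>

// Hypothetical stand-in for per-test tolerances; the field names and
// default values are made up for this sketch.
struct toy_tolerances {
  double gradient_fvar = 1e-8;
  double hessian_fvar = 1e-6;
};

// A test runs only if its tolerance is strictly less than +infinity.
// Setting a tolerance to infinity therefore switches that test off.
inline bool should_run(double tol) {
  return tol < std::numeric_limits<double>::infinity();
}
```

A caller would wrap each derivative test in a check such as if (should_run(tols.hessian_fvar)) { ... }, so setting the forward-mode tolerances to infinity for the FFT functions alone would skip just those tests.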

WardBrian commented 2 years ago

Does adding specializations for the analytic derivatives of FFT/iFFT not fix the test issues?

bob-carpenter commented 2 years ago

I'm not sure what's going on with the automation. I'd prefer to start new issues rather than modify old ones. You can ignore the chatter about FFTs, as it's not really relevant here.

WardBrian commented 2 years ago

I accidentally hit 'Close with comment' instead of 'Comment' which triggered all of the above.

WardBrian commented 2 years ago

Vectorized versions of abs/fabs do not currently work for complex containers; @bob-carpenter thinks he might know why.

bob-carpenter commented 2 years ago

@bob-carpenter thinks he might know why

That's easy---it doesn't work because the vectorization code was commented out. I'll try to sort it out on the C++ side, but I might need some help with template traits if there winds up being a lot of ambiguity.

WardBrian commented 2 years ago

A lot of these were resolved by #2753 and will be exposed in the language by https://github.com/stan-dev/stanc3/pull/1212. I believe the remaining ones all need their templates broadened, which I hope to at least attempt before the feature freeze.

WardBrian commented 2 years ago

Of the remaining ones in this issue, inverse and mdivide_left are just issues with using std::is_arithmetic. Adding an identical overload which uses stan::is_complex works, but autodiff is slow.

For determinant, the same sort of thing does compile, but fails AD tests.
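The std::is_arithmetic issue mentioned above can be illustrated with a toy scalar function. This is only a sketch: is_std_complex is a hypothetical trait defined here as a stand-in for a library-provided complex check, and inverse_scalar is a made-up function, not the library's inverse.

```cpp
#include <complex>
#include <type_traits>

// Hypothetical trait for this sketch: true only for std::complex<T>.
template <typename T>
struct is_std_complex : std::false_type {};
template <typename T>
struct is_std_complex<std::complex<T>> : std::true_type {};

// Overload restricted to arithmetic scalars. std::complex<double> is not
// arithmetic, so this overload alone rejects complex arguments.
template <typename T,
          std::enable_if_t<std::is_arithmetic<T>::value, int> = 0>
T inverse_scalar(T x) {
  return T(1) / x;
}

// Identical body, admitted for complex scalars by the second trait.
template <typename T,
          std::enable_if_t<is_std_complex<T>::value, int> = 0>
T inverse_scalar(T x) {
  return T(1) / x;
}
```

With only the first overload, a call such as inverse_scalar(std::complex<double>(0.0, 2.0)) fails to compile; adding the second makes it resolve. Nothing in this toy says anything about the autodiff speed noted above.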