
automatic differentiation made easier for C++
https://autodiff.github.io
MIT License

Reverse mode asymptotically too slow O(N^14)? #332

Open alecjacobson opened 6 months ago

alecjacobson commented 6 months ago

Apologies that I haven't simplified this example more:

#include <autodiff/forward/dual.hpp>
#include <autodiff/reverse/var.hpp>
#include <autodiff/reverse/var/eigen.hpp>

// llt
#include <Eigen/Core>
#include <Eigen/Cholesky>

// tictoc function to return time since last call in seconds
#include <chrono>
double tictoc()
{
  static std::chrono::time_point<std::chrono::high_resolution_clock> last = std::chrono::high_resolution_clock::now();
  auto now = std::chrono::high_resolution_clock::now();
  double elapsed = std::chrono::duration_cast<std::chrono::microseconds>(now - last).count() * 1e-6;
  last = now;
  return elapsed;
}

// Fill an NxN matrix A with A(i,j) = x(i) + x(j) (diagonal doubled), solve
// A*sol = x via Cholesky (LLT), and return the norm of the solution.
template <typename T, int N>
T f(const Eigen::Matrix<T,N,1> & x)
{
  Eigen::Matrix<T,N,N> A;
  for (int i = 0; i < N; ++i)
  {
    for (int j = 0; j < N; ++j)
    {
      A(i,j) = x(i) + x(j);
    }
    A(i,i) = 2.0*A(i,i);
  }
  auto sol = A.llt().solve(x).eval();
  return sol.norm();
}

#include <iostream>

// For size N: time 100 forward evaluations, a central finite-difference
// gradient, and one reverse-mode gradient, then print the gradient error.
template <int N>
double call_and_time()
{
  Eigen::Matrix<autodiff::Variable<double> ,N,1> x(N);
  // initialize x to 1, 2, 3, ...
  for (int i = 0; i < N; ++i)
    x(i) = 1.0 + i;
  printf("%d ",N);
  tictoc();
  double _ = 0;
  for(int i = 0; i < 100; ++i)
  {
    _ += f(x.template cast<double>().eval());
  }
  printf("%g ",tictoc()/100);
  tictoc();
  // finite difference gradient
  Eigen::Matrix<double,N,1> dydx;
  {
    Eigen::Matrix<double,N,1> xval = x.template cast<double>().eval();
    for (int i = 0; i < N; ++i)
    {
      Eigen::Matrix<double,N,1> xplus = xval;
      xplus(i) += 1e-8;
      Eigen::Matrix<double,N,1> xminus = xval;
      xminus(i) -= 1e-8;
      dydx(i) = (f(xplus) - f(xminus)) / 2e-8;
    }
    printf("%g ",tictoc());
  }
  // reverse-mode: build the expression graph, then propagate the gradient
  auto y = f(x);
  auto dydx_ad = gradient(y, x);
  printf("%g ",tictoc());
  printf("%g ",(dydx-dydx_ad.template cast<double>()).norm());
  printf("\n");
  return _;
}

int main(int argc, char *argv[])
{
  call_and_time<4>();
  call_and_time<5>();
  call_and_time<6>();
  call_and_time<7>();
  call_and_time<8>();
  call_and_time<9>();
  call_and_time<10>();
  call_and_time<11>();
  call_and_time<12>();
}

On my machine this prints (columns: N, average forward evaluation time over 100 calls, finite-difference gradient time, reverse-mode gradient time, and the norm of the difference between the two gradients):

4 2e-07 1e-06 0.000337 4.75513e-09 
5 4.5e-07 2e-06 0.003025 8.62449e-09 
6 4.6e-07 5e-06 0.025268 5.39745e-08 
7 4.9e-07 6e-06 0.109612 1.04286e-07 
8 3.7e-07 4e-06 1.03983 6.32531e-06 
9 5.6e-07 6e-06 9.11204 6.65647e-06 
10 4.3e-07 7e-06 33.2274 6.71564e-06 
11 4.6e-07 9e-06 125.264 6.62523e-06 
12 5e-07 1.1e-05 512.593 6.59076e-06 

The forward evaluation could plausibly be O(N³) (a dense Cholesky solve), which would make finite differences O(N⁴), but these casual timings are far too small for that to show up.

What stands out is the reverse-mode gradient column: fitting a slope to those timings looks like roughly N^14.
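
A quick slope estimate from that column: log(512.593 / 1.03983) / log(12 / 8) ≈ 15.3 between N = 8 and N = 12, and log(512.593 / 0.000337) / log(12 / 4) ≈ 13.0 over the whole range, so the exponent really does sit somewhere around 13-15.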

The .llt().solve seems to trigger some really bad complexity issues here.

I'm sure forward mode would be fine in this case, but in my larger project I am really counting on reverse-mode.
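
For completeness, the forward-mode version I'd expect to be fine would be something like the following untested sketch, using dual and the gradient(f, wrt(x), at(x), u) helper documented at https://autodiff.github.io (forward mode pays one sweep per input component, so N sweeps per gradient):

#include <autodiff/forward/dual.hpp>
#include <autodiff/forward/dual/eigen.hpp>

#include <Eigen/Core>
#include <Eigen/Cholesky>
#include <iostream>

using namespace autodiff;

// Dynamic-size version of the same test function, over dual numbers.
dual f_dual(const VectorXdual & x)
{
  const int n = int(x.size());
  Eigen::Matrix<dual,Eigen::Dynamic,Eigen::Dynamic> A(n,n);
  for (int i = 0; i < n; ++i)
  {
    for (int j = 0; j < n; ++j)
      A(i,j) = x(i) + x(j);
    A(i,i) = 2.0*A(i,i);
  }
  return A.llt().solve(x).eval().norm();
}

int main()
{
  const int n = 12;
  VectorXdual x(n);
  for (int i = 0; i < n; ++i)
    x(i) = 1.0 + i;
  dual y;
  // one forward sweep per input component (n sweeps total)
  Eigen::VectorXd dydx = gradient(f_dual, wrt(x), at(x), y);
  std::cout << dydx.transpose() << std::endl;
  return 0;
}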

allanleal commented 6 months ago

Thank you for reporting this. I'm considering deprecating the var number type (for reverse-mode automatic differentiation in autodiff): it started as an experiment, with some follow-up contributions from users, but I never had the time to implement a definitively faster reverse algorithm. For my own use cases I only need forward AD, so most of the attention has gone into the real and dual number types.

alecjacobson commented 6 months ago

Makes sense. I really, really like how easy this library is to call and use, but the performance here is unfortunately a deal-breaker.

FWIW, I tried a simple reverse-mode implementation based on https://github.com/Rookfighter/autodiff-cpp/blob/master/include/adcpp.h and it appears to hit the same performance explosion on the case above, so perhaps there's a common pitfall here. I'd love to understand what it is.

Meanwhile, after a long battle with CMake, I got https://mc-stan.org/math/ working. Here's an analogous benchmark:

// THIS MUST BE INCLUDED BEFORE ANY EIGEN HEADERS
#include <stan/math.hpp>

#include <Eigen/Dense>
#include <iostream>
// tictoc
#include <chrono>
double tictoc()
{
  double t = std::chrono::duration<double>(
      std::chrono::system_clock::now().time_since_epoch()).count();
  static double t0 = t;
  double tdiff = t-t0;
  t0 = t;
  return tdiff;
}

// Same test function as above (this one returns squaredNorm rather than norm)
template <typename T, int N>
T llt_func(Eigen::Matrix<T, N, 1> & x)
{
  Eigen::Matrix<T, N, N> A;
  for(int i = 0; i < N; i++)
  {
    for(int j = 0; j < N; j++)
    {
      A(i, j) = x(i)+x(j);
    }
    A(i,i) = 2.0*A(i,i);
  }
  Eigen::Matrix<T, N, 1> b = A.llt().solve(x);
  T y = b.squaredNorm();
  return y;
}

// stan::math::gradient only supports Eigen::Dynamic
template <typename T, int N>
T llt_func_helper( Eigen::Matrix<T, Eigen::Dynamic, 1> & _x)
{
  Eigen::Matrix<T, N, 1> x = _x.template head<N>();
  return llt_func<T,N>(x);
}

// Average max_iter reverse-mode gradient evaluations at size N, print the
// result, then recurse to size N+1.
template <int N, int max_N>
void benchmark()
{
  tictoc();
  const int max_iter = 1000;
  Eigen::Matrix<double, Eigen::Dynamic, 1> dydx(N);
  double yt = 0;
  for(int iter = 0;iter < max_iter;iter++)
  {
    Eigen::Matrix<double, Eigen::Dynamic, 1> x(N);
    for(int i  = 0; i < N; i++)
    {
      x(i) = i+1;
    }
    double y;
    stan::math::gradient(llt_func_helper<stan::math::var,N>, x, y, dydx);
    yt += y;
  }

  printf("%d %g \n",N,tictoc()/max_iter);
  if constexpr (N<max_N)
  {
    benchmark<N+1,max_N>();
  }
}

int main() 
{
  benchmark<1,30>();
  return 0;
}

And its performance seems excellent in comparison (columns: N, average time per gradient in seconds):

1 4.21047e-07 
2 7.54833e-07 
3 1.36113e-06 
4 2.2819e-06 
5 3.79109e-06 
6 5.584e-06 
7 6.61087e-06 
8 8.51417e-06 
9 8.89301e-06 
10 8.73613e-06 
11 8.68988e-06 
12 8.71086e-06 
13 8.80289e-06 
14 8.87012e-06 
15 8.77094e-06 
16 9.22322e-06 
17 1.0052e-05 
18 1.10841e-05 
19 1.1492e-05 
20 1.2831e-05 
21 1.401e-05 
22 1.49281e-05 
23 1.6151e-05 
24 1.73478e-05 
25 1.82691e-05 
26 1.96209e-05 
27 2.08769e-05 
28 2.22659e-05 
29 2.35419e-05 
30 2.5017e-05 
alecjacobson commented 6 months ago

Following the tape-based implementation described at https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation, I was able to get performance on this test case that matches the stan-math trend (though it's still about 2× slower); a minimal sketch of the tape idea follows the timings below:

1 9.21e-07
2 2.755e-06
3 5.014e-06
4 6.388e-06
5 7.284e-06
6 9.185e-06
7 5.28e-06
8 9.678e-06
9 1.1366e-05
10 1.2893e-05
11 1.3666e-05
12 1.3365e-05
13 1.3695e-05
14 1.5345e-05
15 1.7023e-05
16 2.0021e-05
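
For anyone who finds this later, the core of that tape-based approach (a heavily simplified sketch loosely following the rufflewind post, not the actual code I benchmarked) is just a flat array of nodes plus one linear backward sweep:

#include <cstdio>
#include <vector>

// Each node records its two parents on the tape together with the local
// partial derivatives of its value with respect to those parents.
struct Tape
{
  struct Node { double w[2]; int parent[2]; };
  std::vector<Node> nodes;

  int push(double w0, int p0, double w1, int p1)
  {
    nodes.push_back({{w0, w1}, {p0, p1}});
    return int(nodes.size()) - 1;
  }
  // an independent input: its parents point at itself with zero weight
  int push_leaf()
  {
    const int i = int(nodes.size());
    return push(0.0, i, 0.0, i);
  }
};

struct Var
{
  Tape * tape;
  int index;
  double value;
};

Var make_leaf(Tape & t, double v) { return {&t, t.push_leaf(), v}; }

// d(a+b)/da = 1, d(a+b)/db = 1
Var operator+(const Var & a, const Var & b)
{
  return {a.tape, a.tape->push(1.0, a.index, 1.0, b.index), a.value + b.value};
}

// d(a*b)/da = b, d(a*b)/db = a
Var operator*(const Var & a, const Var & b)
{
  return {a.tape, a.tape->push(b.value, a.index, a.value, b.index), a.value * b.value};
}

// Backward sweep: a single linear pass over the tape, independent of how many
// inputs there are, accumulating adjoints from the output back to the leaves.
std::vector<double> grad(const Var & y)
{
  std::vector<double> adjoint(y.tape->nodes.size(), 0.0);
  adjoint[y.index] = 1.0;
  for (int i = int(adjoint.size()) - 1; i >= 0; --i)
  {
    const Tape::Node & n = y.tape->nodes[i];
    adjoint[n.parent[0]] += n.w[0] * adjoint[i];
    adjoint[n.parent[1]] += n.w[1] * adjoint[i];
  }
  return adjoint;
}

int main()
{
  Tape tape;
  Var x = make_leaf(tape, 2.0);
  Var y = make_leaf(tape, 3.0);
  Var z = x * y + x;
  std::vector<double> dz = grad(z);
  printf("dz/dx = %g, dz/dy = %g\n", dz[x.index], dz[y.index]); // expect 4 and 2
  return 0;
}

Hooking something like this into Eigen's LLT still needs the full operator set (and an Eigen::NumTraits specialization), but the property that matters is that the backward sweep is one pass over the tape, so the gradient cost stays proportional to the cost of the forward evaluation rather than exploding.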