Rookfighter / autodiff-cpp

A single header-only C++ library for automatic / algorithmic differentiation.
MIT License
12 stars 2 forks source link

Exponential explosion of runtime for reverse mode #3

Open alecjacobson opened 5 months ago

alecjacobson commented 5 months ago

Thank you for sharing this code base. This code really helped me understand a possible organization for reverse-mode autodiff in C++.

The bwd::Number type appears to be hitting the same performance issue as autodiff.github.io's reverse mode: https://github.com/autodiff/autodiff/issues/332

Here's the smallest reproducible example I could create.

#include "adcpp/adcpp.hpp"
#include <vector>

// tictoc
#include <chrono>
double tictoc()
{
  double t = std::chrono::duration<double>(
      std::chrono::system_clock::now().time_since_epoch()).count();
  static double t0 = t;
  double tdiff = t-t0;
  t0 = t;
  return tdiff;
}

template <typename T, int N>
T simple_func()
{
  // O(n²)
  std::vector<T> A(N*N,0);
  // O(n³)
  for(int i = 0; i < N; i++)
  {
    for(int k = 0; k < N; k++)
    {
      for(int j = 0; j < N; j++)
      {
        A[k*N+j] = A[k*N+i]*A[i*N+j];
      }
    }
  }
  // O(n²)
  T y = 0;
  for(int i = 0; i < N; i++)
    for(int j = 0; j < N; j++)
      y += A[i*N+j];
  return y;
}

template <int N, int max_N>
void benchmark()
{
  tictoc();
  const int max_iter = 10/N;
  for(int iter = 0; iter < max_iter; iter++)
  {
    adcpp::bwd::Double y = simple_func<adcpp::bwd::Double,N>();
    adcpp::bwd::Double::DerivativeMap derivative;
    y.derivative(derivative);
  }
  printf("%d %g \n",N,tictoc()/max_iter);
  if constexpr (N<max_N)
  {
    benchmark<N+1,max_N>();
  }
}

int main()
{
  benchmark<1,10>();
  return 0;
}

The forward pass is O(n³) and outputs a scalar, but the runtime of taking the derivative appears to be O(exp(n)).

It seems that Expression::derivative or its derived classes' derivative is being called an exponential number of times.

Perhaps this is related to the common pitfall of implementing fibonacci numbers or factorials using recursion. Just a guess!