stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

Separating out callbacks from vari #2111

Open SteveBronder opened 3 years ago

SteveBronder commented 3 years ago

Description

I'm sure Tadej and others have thought of this already, but I wanted to write down a neat idea before I forget. I was working on some slides describing autodiff and realized that if we use reverse_pass_callback() everywhere, we should be able to separate the callbacks in var_stack from the memory we allocate for varis. This will save some memory, help the compiler with devirtualization, and, I think, give a much cleaner pattern overall.

So we just need callback_vari to look something like:

struct callback_base {
  virtual void chain() = 0;
};

// Stores one reverse-pass functor of type F
template <typename F>
struct callback_vari final : public callback_base {
  F rev_functor_;

  explicit callback_vari(F&& rev_functor)
      : rev_functor_(std::forward<F>(rev_functor)) {
    // register this callback on the global autodiff stack
    autodiff_stacks.callback_stack_.push_back(this);
  }
  // the only virtual call left in the reverse pass
  void chain() override { rev_functor_(); }
};

And in our ChainableStack we replace the vari_stack_ vector with a callback_stack_, like:

    std::vector<callback_base*> callback_stack_;

And then grad() also gets updated (nesting works fine; I just cut it out of the example I'm writing):

static void grad() {
  size_t end = autodiff_stacks.callback_stack_.size();
  // walk the callbacks newest-first (reverse order of creation)
  for (size_t i = end; i-- > 0;) {
    autodiff_stacks.callback_stack_[i]->chain();
  }
}
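A compilable sketch of the pieces above (callback_base, callback_vari, the callback stack, and grad()) may help make the control flow concrete. It rests on assumptions not in the issue: callbacks are owned by std::unique_ptr instead of Stan's arena memory (so callback_base gets a virtual destructor), a plain global `callback_stack` stands in for `autodiff_stacks.callback_stack_`, and `reverse_pass_callback` here is a simplified stand-in, not Stan's implementation.

```cpp
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

// The only virtual function lives here, not on vari.
struct callback_base {
  virtual void chain() = 0;
  // Needed in this sketch because callbacks are deleted through base pointers.
  virtual ~callback_base() = default;
};

// Holds one reverse-pass functor by value.
template <typename F>
struct callback_vari final : public callback_base {
  F rev_functor_;
  explicit callback_vari(F&& rev_functor)
      : rev_functor_(std::move(rev_functor)) {}
  void chain() override { rev_functor_(); }
};

// Assumption: unique_ptr ownership replaces Stan's arena allocator here.
std::vector<std::unique_ptr<callback_base>> callback_stack;

// Simplified stand-in: decay the lambda type so callback_vari stores a copy.
template <typename F>
void reverse_pass_callback(F&& functor) {
  using G = std::decay_t<F>;
  callback_stack.push_back(
      std::make_unique<callback_vari<G>>(G(std::forward<F>(functor))));
}

// Reverse sweep: newest callback first, as in grad() above.
void grad() {
  for (std::size_t i = callback_stack.size(); i-- > 0;) {
    callback_stack[i]->chain();
  }
}
```

Since grad() walks the stack from the top, callbacks registered in the order 1, 2, 3 run as 3, 2, 1, which is the order the reverse pass needs.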

This is nice because it shrinks a vari from 24 bytes to 16 bytes: we get rid of the pointer to the vtable for chain() (it moves over to callback_base). So things like addition become both cleaner and leaner:

inline var operator+(double a, const var& b) {
  if (a == 0.0) {
    return b;  // 0 + b is just b: no new vari, no callback
  }
  var ret(a + b.val());
  reverse_pass_callback([b, ret](){
    b.adj() += ret.adj();
  });
  return ret;
}

Full example with addition here:

https://godbolt.org/z/chGdcY
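For readers without the godbolt link at hand, here is a small self-contained sketch of the same pattern: storage-only varis plus callbacks, with the double + var addition from above and a var * var product. Assumptions, not from the issue: `var` is a minimal hypothetical handle far simpler than Stan's, a `std::vector<std::unique_ptr<vari>>` stands in for the arena, and `std::function<void()>` stands in for callback_base/callback_vari to keep the example short.

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Storage-only vari: a value and an adjoint, no vtable.
struct vari {
  double val_;
  double adj_ = 0.0;
  explicit vari(double v) : val_(v) {}
};

// Assumptions of this sketch: stand-ins for the arena and the callback stack.
std::vector<std::unique_ptr<vari>> vari_arena;
std::vector<std::function<void()>> callback_stack;

// Hypothetical minimal var: just a handle to a vari.
struct var {
  vari* vi_;
  var(double v) {  // implicit, so doubles convert to var
    vari_arena.push_back(std::make_unique<vari>(v));
    vi_ = vari_arena.back().get();
  }
  double val() const { return vi_->val_; }
  double& adj() const { return vi_->adj_; }
};

// Simplified stand-in for reverse_pass_callback.
template <typename F>
void reverse_pass_callback(F&& f) {
  callback_stack.emplace_back(std::forward<F>(f));
}

inline var operator+(double a, const var& b) {
  if (a == 0.0) {
    return b;  // 0 + b is b itself: no new node, no callback
  }
  var ret(a + b.val());
  reverse_pass_callback([b, ret]() { b.adj() += ret.adj(); });
  return ret;
}

inline var operator*(const var& a, const var& b) {
  var ret(a.val() * b.val());
  reverse_pass_callback([a, b, ret]() {
    a.adj() += ret.adj() * b.val();
    b.adj() += ret.adj() * a.val();
  });
  return ret;
}

// Seed the output adjoint, then run the callbacks newest-first.
inline void grad(const var& f) {
  f.adj() = 1.0;
  for (std::size_t i = callback_stack.size(); i-- > 0;) {
    callback_stack[i]();
  }
}
```

With x = 3 and y = 4, f = 2.0 + x * y has value 14, and grad(f) leaves x.adj() == 4 and y.adj() == 3. The zero short-circuit hands back b's own vari, so `0.0 + x` shares x's node.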

Current Version:

v3.3.0

bbbales2 commented 3 years ago

This is nice because it makes a vari go from 24 bytes to 16 bytes

That is really nice.

So varis would no longer have chain() methods and would just be storage for values and adjoints?

SteveBronder commented 3 years ago

Yep!
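To illustrate the size change discussed above, a sketch comparing a vari-like struct that keeps a virtual chain() with a storage-only one. The struct names are hypothetical. Exact sizes depend on the ABI, but on a typical 64-bit target the first is 24 bytes (hidden vtable pointer plus two doubles) and the second is 16.

```cpp
#include <cstdio>

// vari-like layout with a virtual chain(): carries a hidden vtable pointer.
struct vari_with_chain {
  double val_;
  double adj_ = 0.0;
  explicit vari_with_chain(double v) : val_(v) {}
  virtual void chain() {}
};

// Proposed layout: just a value and an adjoint.
struct vari_storage_only {
  double val_;
  double adj_ = 0.0;
  explicit vari_storage_only(double v) : val_(v) {}
};

void print_sizes() {
  std::printf("with chain(): %zu bytes\n", sizeof(vari_with_chain));
  std::printf("storage only: %zu bytes\n", sizeof(vari_storage_only));
}
```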