The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
I'm sure Tadej and others have thought of this already, but I wanted to write down a neat idea before I forget. I was working on some slides to describe autodiff and realized that if we use `reverse_pass_callback()` everywhere, we should be able to separate the callbacks in `var_stack` from the memory we allocate for varis. This would save some memory, help the compiler with devirtualization, and I think give a much cleaner pattern overall.
So `callback_vari` just has to derive from a small `callback_base` class that holds the virtual `chain()`.
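A minimal sketch of what `callback_base`, `callback_vari`, and a simplified `reverse_pass_callback()` could look like under this scheme, assuming C++14. This is not Stan's actual code: the `callback_stack()` accessor and `demo_single_callback()` are hypothetical, and callbacks are allocated with plain `new` instead of Stan's arena allocator.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>
#include <vector>

// Only callbacks carry a vtable; vari itself would no longer need one.
struct callback_base {
  virtual void chain() = 0;
  virtual ~callback_base() = default;
};

// Wraps any nullary functor (typically a lambda) behind callback_base*.
template <typename F>
struct callback_vari final : public callback_base {
  F rev_functor_;
  explicit callback_vari(F f) : rev_functor_(std::move(f)) {}
  void chain() override { rev_functor_(); }
};

// Hypothetical stand-in for ChainableStack::callback_stack_.
inline std::vector<callback_base*>& callback_stack() {
  static std::vector<callback_base*> stack;
  return stack;
}

// Simplified reverse_pass_callback(): wrap the functor and push it.
// Uses plain new here instead of Stan's arena allocator.
template <typename F>
void reverse_pass_callback(F&& f) {
  callback_stack().push_back(
      new callback_vari<std::decay_t<F>>(std::forward<F>(f)));
}

// Registers one callback, runs it through the base pointer, then frees it.
inline double demo_single_callback() {
  double adj = 0.0;
  reverse_pass_callback([&adj]() { adj += 1.0; });
  callback_base* cb = callback_stack().back();
  callback_stack().pop_back();
  cb->chain();  // virtual dispatch to callback_vari<lambda>::chain()
  delete cb;
  return adj;
}
```

The only virtual call left is `callback_base::chain()`; each lambda's body is known at the point where `callback_vari<F>` is instantiated, which is what gives the compiler room to inline it.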
In our `ChainableStack` we then replace the `vari_stack_` vector with a `callback_stack_`:

```cpp
std::vector<callback_base*> callback_stack_;
```
Then `grad()` gets updated too (nested autodiff works fine; I just cut it out of the example I'm writing):
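```cpp
static void grad() {
  size_t end = autodiff_stacks.callback_stack_.size();
  // Walk the callbacks backwards so each chain() runs after everything
  // that depends on its output.
  for (size_t i = end; i-- > 0;) {
    autodiff_stacks.callback_stack_[i]->chain();
  }
}
```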
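To check the ordering, here is a self-contained toy version of the callback stack and the `grad()` loop above, assuming C++14. The `chainable_stack_sketch` struct, `recover_memory()`, and `demo_reverse_order()` are hypothetical, and callbacks are freed with `delete` rather than by resetting an arena. Because `grad()` walks `callback_stack_` backwards, callbacks fire in the reverse of the order they were registered, which is what reverse-mode autodiff needs.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

struct callback_base {
  virtual void chain() = 0;
  virtual ~callback_base() = default;
};

template <typename F>
struct callback_vari final : public callback_base {
  F rev_functor_;
  explicit callback_vari(F f) : rev_functor_(std::move(f)) {}
  void chain() override { rev_functor_(); }
};

// Hypothetical, stripped-down ChainableStack holding only callbacks.
struct chainable_stack_sketch {
  std::vector<callback_base*> callback_stack_;
};
static chainable_stack_sketch autodiff_stacks;

template <typename F>
void reverse_pass_callback(F&& f) {
  autodiff_stacks.callback_stack_.push_back(
      new callback_vari<std::decay_t<F>>(std::forward<F>(f)));
}

// Same loop as grad() above: newest callback first.
static void grad() {
  std::size_t end = autodiff_stacks.callback_stack_.size();
  for (std::size_t i = end; i-- > 0;) {
    autodiff_stacks.callback_stack_[i]->chain();
  }
}

// Frees the callbacks; the real thing would reset its arena instead.
static void recover_memory() {
  for (callback_base* cb : autodiff_stacks.callback_stack_) {
    delete cb;
  }
  autodiff_stacks.callback_stack_.clear();
}

// Registers callbacks 1, 2, 3 and records the order grad() runs them in.
static std::vector<int> demo_reverse_order() {
  std::vector<int> order;
  for (int k = 1; k <= 3; ++k) {
    reverse_pass_callback([&order, k]() { order.push_back(k); });
  }
  grad();
  recover_memory();
  return order;
}
```

`demo_reverse_order()` returns `{3, 2, 1}`.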
This is nice because it shrinks a `vari` from 24 bytes to 16 bytes: we get rid of the vtable pointer needed for `chain()` (it moves over to `callback_base`), leaving only `val_` and `adj_`. So things like addition become both cleaner and leaner:
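```cpp
inline var operator+(double a, const var& b) {
  if (a == 0.0) {
    return b;  // 0 + b is just b; no new node or callback needed
  }
  var ret(a + b.val());
  // The only reverse-mode work: d(a + b)/db = 1, so pass ret's adjoint to b.
  reverse_pass_callback([b, ret]() {
    b.adj() += ret.adj();
  });
  return ret;
}
```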
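The 24-versus-16-byte figure above can be checked with two simplified stand-ins for `vari` (hypothetical types, not Stan's actual class). Any class with a virtual function carries a hidden vtable pointer, so on a typical 64-bit ABI the virtual version is 8 + 8 + 8 = 24 bytes and the plain one 8 + 8 = 16 bytes; other platforms may differ.

```cpp
#include <cassert>
#include <cstddef>

// Simplified stand-in for the current vari: a virtual chain() means every
// object carries a hidden vtable pointer next to its data.
struct vari_with_vtable {
  double val_;
  double adj_;
  virtual void chain() {}
  virtual ~vari_with_vtable() = default;
};

// Simplified stand-in for the proposed vari: plain data, no vtable.
struct vari_no_vtable {
  double val_;
  double adj_;
};

constexpr std::size_t bytes_with_vtable = sizeof(vari_with_vtable);
constexpr std::size_t bytes_without_vtable = sizeof(vari_no_vtable);

// Extra bytes the vtable pointer costs across n nodes on the tape.
constexpr std::size_t vtable_overhead(std::size_t n) {
  return n * (bytes_with_vtable - bytes_without_vtable);
}
```

On such a platform, a tape with a million nodes would carry about 8 MB less.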
Full example with addition here:
https://godbolt.org/z/chGdcY
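For reference, here is a toy end-to-end version of the addition example, assuming C++14. It is not the code behind the godbolt link: `var`, `vari`, the `unique_ptr`-based `vari_arena`, `grad(const var&)`, `recover_memory()`, and `demo_add()` are simplified, hypothetical stand-ins. `vari` holds only `val_` and `adj_`, while the lambda passed to `reverse_pass_callback()` carries the chain rule. The zero shortcut returns `b`; returning `a` instead would hand back a fresh constant `var` and drop the dependence on `b`.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

struct callback_base {
  virtual void chain() = 0;
  virtual ~callback_base() = default;
};

template <typename F>
struct callback_vari final : public callback_base {
  F rev_functor_;
  explicit callback_vari(F f) : rev_functor_(std::move(f)) {}
  void chain() override { rev_functor_(); }
};

// Plain data: no virtual functions, so no vtable pointer.
struct vari {
  double val_;
  double adj_ = 0.0;
  explicit vari(double v) : val_(v) {}
};

// Hypothetical stand-ins for Stan's arena and callback stack.
static std::vector<std::unique_ptr<vari>> vari_arena;
static std::vector<callback_base*> callback_stack;

// Toy var: a handle to a vari (copies share the same node).
struct var {
  vari* vi_;
  var(double v) {  // implicit, so `return a;` style conversions compile
    auto node = std::make_unique<vari>(v);
    vi_ = node.get();
    vari_arena.push_back(std::move(node));
  }
  double val() const { return vi_->val_; }
  double& adj() const { return vi_->adj_; }
};

template <typename F>
void reverse_pass_callback(F&& f) {
  callback_stack.push_back(
      new callback_vari<std::decay_t<F>>(std::forward<F>(f)));
}

inline var operator+(double a, const var& b) {
  if (a == 0.0) {
    return b;  // reuse b's node; nothing to record
  }
  var ret(a + b.val());
  reverse_pass_callback([b, ret]() { b.adj() += ret.adj(); });
  return ret;
}

// Seeds the output adjoint and runs the callbacks newest-first.
void grad(const var& y) {
  y.adj() = 1.0;
  for (std::size_t i = callback_stack.size(); i-- > 0;) {
    callback_stack[i]->chain();
  }
}

void recover_memory() {
  for (callback_base* cb : callback_stack) {
    delete cb;
  }
  callback_stack.clear();
  vari_arena.clear();
}

struct add_demo_result {
  double value;           // z = 0 + (3 + (2 + x)) at x = 5
  double dz_dx;           // derivative of z with respect to x
  bool zero_reused_node;  // whether 0 + y returned y's own vari
};

add_demo_result demo_add() {
  var x(5.0);
  var y = 3.0 + (2.0 + x);
  var z = 0.0 + y;
  grad(z);
  add_demo_result r{z.val(), x.adj(), z.vi_ == y.vi_};
  recover_memory();
  return r;
}
```

`demo_add()` should report z = 10 with dz/dx = 1, and the `0.0 + y` step reuses `y`'s node instead of recording a callback.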
Current Version:
v3.3.0