This is an umbrella task that we will split into smaller tasks upon further discussion.
Reverse-mode automatic differentiation creates new "context" functions that run the original function while recording the necessary intermediate values in an "intermediate context" struct, which is then passed as an input to the function that actually performs the derivative backpropagation. The size of this context tends to be the most critical factor affecting the performance of the resulting derivative kernel, since it translates directly into per-thread memory overhead (and the resulting memory traffic).
Deciding which intermediate values go into this context and which are recomputed later is a key open problem in automatic differentiation, often referred to as the "checkpointing" problem. In Slang, we currently do not attempt anything clever: we use a fixed heuristic and allow the user to control some of the choices via `[PreferRecompute]` and `[PreferCheckpoint]` decorations on function declarations. At the moment, the resulting decisions made by the compiler (and their impact on the final memory usage) are opaque. The only way to inspect them is to open up the generated code and try to cross-reference the storage instructions with the source code.
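To make the store-versus-recompute trade-off concrete, here is a toy Python model for `y = sin(x)^2`. It is only an illustration of the idea, not how Slang actually generates code: the dict stands in for the context struct, and the 4-bytes-per-entry size assumes every stored value is a `float`.

```python
import math

FLOAT_BYTES = 4  # assumption: each stored value is a 32-bit float


def context_bytes(ctx):
    """Size of the toy context 'struct' in bytes."""
    return FLOAT_BYTES * len(ctx)


def fwd_checkpoint(x):
    # Primal pass that checkpoints the intermediate sin(x) in the context.
    s = math.sin(x)
    return s * s, {"x": x, "s": s}


def bwd_checkpoint(ctx, dy):
    # d/dx sin(x)^2 = 2 sin(x) cos(x); sin(x) is read back from the context.
    return dy * 2.0 * ctx["s"] * math.cos(ctx["x"])


def fwd_recompute(x):
    # Primal pass that stores only the input; sin(x) will be recomputed.
    s = math.sin(x)
    return s * s, {"x": x}


def bwd_recompute(ctx, dy):
    # sin(x) is recomputed here instead of being loaded from the context.
    s = math.sin(ctx["x"])
    return dy * 2.0 * s * math.cos(ctx["x"])


y1, c1 = fwd_checkpoint(0.5)
y2, c2 = fwd_recompute(0.5)
print(context_bytes(c1), context_bytes(c2))  # checkpointed vs. recomputed context size
```

Both variants produce the same gradient; the checkpointed one doubles the context size to save one `sin` evaluation in the backward pass. Multiplied across every intermediate and every thread, this is the choice the heuristic (and the decorations) currently make without any visible feedback.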
The first step toward making Slang auto-diff more accessible is to provide feedback to the user. This issue proposes doing this through an API endpoint (and a corresponding CLI switch) that reports memory-footprint information on a per-instruction basis (i.e., was this instruction's value stored, and if so, how many bytes does it occupy in the context struct of the outermost differentiated function?).
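The exact shape of the report is still open for discussion. As a purely hypothetical sketch (`InstReport`, its fields, and the `file:line` location format are illustrative names, not existing Slang API), per-instruction records could carry the two facts asked for above, and tools could derive the total context size and the largest contributors from them:

```python
from dataclasses import dataclass


@dataclass
class InstReport:
    # Hypothetical per-instruction record; not an existing Slang API.
    source_loc: str   # source location of the instruction, e.g. "file.slang:12"
    stored: bool      # True if the value was checkpointed into the context
    size_bytes: int   # bytes occupied in the outermost context struct (0 if recomputed)


def total_context_bytes(reports):
    """Sum the footprint of every stored instruction."""
    return sum(r.size_bytes for r in reports if r.stored)


def largest_contributors(reports, n):
    """Return the n stored instructions with the largest footprint, largest first."""
    stored = [r for r in reports if r.stored]
    return sorted(stored, key=lambda r: r.size_bytes, reverse=True)[:n]


reports = [
    InstReport("shader.slang:10", True, 16),
    InstReport("shader.slang:11", False, 0),
    InstReport("shader.slang:14", True, 4),
]
for r in largest_contributors(reports, 2):
    print(f"{r.source_loc}: {r.size_bytes} bytes")
print("total:", total_context_bytes(reports), "bytes")
```

Keeping recomputed instructions in the report (with `stored=False`) lets users confirm that a `[PreferRecompute]` decoration took effect, not only see what was stored.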