MilesCranmer opened 1 year ago
Also, I'm not sure whether it's better to implement `frule` and `rrule` for `eval_tree_array`, or for the individual `deg1_eval`/`deg2_eval` functions. Maybe the latter would make it easier to get higher-order derivatives?
I think if we implement `rrule` at the `eval_tree_array` level, it might actually be more expensive, because we'd have to do a matrix inverse on the `frule`... So for true reverse-mode derivatives we'd probably want to implement it at the `deg1_eval` level.
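To make the per-node idea concrete, here is a minimal sketch in Python. It is not Julia and not DynamicExpressions.jl's API: `Node`, `UNARY`, `BINARY` and `value_and_gradient` are hypothetical names. The assumption is that each degree-1 op (the `deg1_eval` analogue) and degree-2 op (the `deg2_eval` analogue) knows only its own local derivative. A single forward pass records values and builds pullback closures, roughly what per-node `rrule`s would provide. A single reverse pass then pushes the output adjoint down the tree and yields the gradient with respect to every input at once, without forming or inverting a Jacobian.

```python
import math
from dataclasses import dataclass
from typing import Optional

# Hypothetical expression-tree node (not the DynamicExpressions.jl type).
# Leaves: op == "feature" (reads X[feature]) or op == "const" (uses value).
# Internal nodes: a unary op uses `left`; a binary op uses `left` and `right`.
@dataclass
class Node:
    op: str
    feature: int = 0
    value: float = 0.0
    left: Optional["Node"] = None
    right: Optional["Node"] = None

# Local rules for degree-1 ops: x -> (f(x), df/dx)
UNARY = {
    "cos": lambda x: (math.cos(x), -math.sin(x)),
    "exp": lambda x: (math.exp(x), math.exp(x)),
}

# Local rules for degree-2 ops: (x, y) -> (f(x, y), df/dx, df/dy)
BINARY = {
    "+": lambda x, y: (x + y, 1.0, 1.0),
    "*": lambda x, y: (x * y, y, x),
}

def value_and_gradient(tree, X):
    """Evaluate `tree` at input vector X; return (value, d value / d X)."""
    grad = [0.0] * len(X)

    def forward(node):
        # Returns (value, pullback). The pullback takes the adjoint of this
        # node's output and accumulates contributions into `grad`.
        if node.op == "feature":
            i = node.feature
            def pb_feature(adj):
                grad[i] += adj
            return X[i], pb_feature
        if node.op == "const":
            return node.value, lambda adj: None
        if node.op in UNARY:
            x, pb_x = forward(node.left)
            y, dydx = UNARY[node.op](x)
            return y, lambda adj: pb_x(adj * dydx)
        xl, pb_l = forward(node.left)
        xr, pb_r = forward(node.right)
        out, d_l, d_r = BINARY[node.op](xl, xr)
        def pb_binary(adj):
            pb_l(adj * d_l)
            pb_r(adj * d_r)
        return out, pb_binary

    value, pullback = forward(tree)
    pullback(1.0)  # seed d(output)/d(output) = 1 and run the reverse pass once
    return value, grad
```

For example, `x0 * cos(x1) + 2.0` is `Node("+", left=Node("*", left=Node("feature", feature=0), right=Node("cos", left=Node("feature", feature=1))), right=Node("const", value=2.0))`. Calling `value_and_gradient` on it at `[3.0, 0.5]` returns the value together with both partial derivatives `[cos(0.5), -3 sin(0.5)]` from one reverse sweep.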
Actually, Enzyme seems to be doing okay! It requires much less maintenance and is very likely the faster option. See https://github.com/EnzymeAD/Enzyme.jl/issues/810 for updates.
Apparently the proper way to build in differentiability is to define rules for ChainRulesCore.jl: https://github.com/JuliaDiff/ChainRulesCore.jl. Specifically, we would define an `frule` (forward) and an `rrule` (reverse). Then evaluations inside DynamicExpressions.jl could be linked in a "chain" within a larger AD pipeline, which might make a lot of other things easier. Maybe if we do this, it would be relatively easy to get even higher-order gradients in DynamicExpressions.jl?
@kazewong check this out