SymbolicML / DynamicExpressions.jl

Ridiculously fast symbolic expressions
https://symbolicml.org/DynamicExpressions.jl/dev
Apache License 2.0
92 stars, 12 forks

[Feature] Interface with ChainRulesCore.jl #29

Open. MilesCranmer opened this issue 1 year ago.

MilesCranmer commented 1 year ago

Apparently the proper way to build in differentiability is to define rules for ChainRulesCore.jl: https://github.com/JuliaDiff/ChainRulesCore.jl. Specifically, we would define an frule (forward-mode rule) and an rrule (reverse-mode rule). Then evaluations inside DynamicExpressions.jl could be linked into a "chain" within a larger AD pipeline. That might make a lot of other things easier, such as differentiating through expression evaluation as one step of a bigger model.
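A minimal sketch of what the two rules look like, assuming ChainRulesCore.jl is installed. The function eval_toy is hypothetical. It stands in for evaluating the expression c * cos(x) over a batch of inputs and is not part of the DynamicExpressions.jl API. The frule pushes input tangents forward to an output tangent. The rrule returns the output plus a pullback that maps an output cotangent back to input cotangents.

```julia
using ChainRulesCore

# Hypothetical stand-in for evaluating the expression `c * cos(x)` elementwise.
# This is NOT a DynamicExpressions.jl function.
eval_toy(c::Real, x::AbstractVector{<:Real}) = c .* cos.(x)

# Forward-mode rule: given tangents (Δc, Δx), return (output, output tangent).
# Assumes numeric tangents for simplicity (no ZeroTangent handling).
function ChainRulesCore.frule((_, dc, dx), ::typeof(eval_toy), c, x)
    y = eval_toy(c, x)
    dy = dc .* cos.(x) .- c .* sin.(x) .* dx   # ∂y/∂c ⋅ Δc + ∂y/∂x ⋅ Δx
    return y, dy
end

# Reverse-mode rule: return the output and a pullback computing J' * ȳ.
function ChainRulesCore.rrule(::typeof(eval_toy), c, x)
    y = eval_toy(c, x)
    function eval_toy_pullback(ybar)
        ybar = unthunk(ybar)
        cbar = sum(ybar .* cos.(x))      # c is a scalar, so sum over the batch
        xbar = -c .* sin.(x) .* ybar     # elementwise: ∂y_i/∂x_i = -c sin(x_i)
        return NoTangent(), cbar, xbar   # NoTangent() for the function itself
    end
    return y, eval_toy_pullback
end
```

Any ChainRules-aware AD system could then call these rules instead of tracing through eval_toy itself.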

If we do this, maybe it would also be relatively easy to get higher-order gradients in DynamicExpressions.jl?

@kazewong check this out

MilesCranmer commented 1 year ago

Also, I'm not sure whether it's better to implement frule and rrule for eval_tree_array as a whole, or for the individual deg1_eval/deg2_eval operator evaluations. Maybe the latter makes it easier to get higher-order derivatives?

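I think an rrule at the eval_tree_array level might actually be more expensive, because we'd have to derive it from the frule. That means assembling the full Jacobian with one forward pass per input, then multiplying its transpose by the incoming cotangent. So for true reverse-mode derivatives, we'd probably want to implement the rules at the deg1_eval/deg2_eval level.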
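To illustrate the per-node approach, here is a sketch assuming ChainRulesCore.jl. The functions node_cos (a degree-1 operator) and node_mul (a degree-2 operator) are hypothetical stand-ins for deg1_eval/deg2_eval. They are not the library's actual functions or signatures. The forward sweep over the tree x1 * cos(x2) records one pullback per node. The backward sweep replays those pullbacks in reverse order, so the gradient comes out of a single backward pass without ever forming a Jacobian.

```julia
using ChainRulesCore

# Hypothetical per-node evaluators (NOT DynamicExpressions.jl's deg1_eval/deg2_eval).
node_cos(x::AbstractVector) = cos.(x)                      # degree-1 operator
node_mul(a::AbstractVector, b::AbstractVector) = a .* b    # degree-2 operator

function ChainRulesCore.rrule(::typeof(node_cos), x)
    y = node_cos(x)
    node_cos_pullback(dy) = (NoTangent(), -sin.(x) .* unthunk(dy))
    return y, node_cos_pullback
end

function ChainRulesCore.rrule(::typeof(node_mul), a, b)
    y = node_mul(a, b)
    function node_mul_pullback(dy)
        dy = unthunk(dy)
        return NoTangent(), dy .* b, dy .* a   # product rule, one cotangent per child
    end
    return y, node_mul_pullback
end

# Hand-written reverse pass for the tree `x1 * cos(x2)`:
# evaluate children before parents, then apply pullbacks root-first.
function tree_gradient(x1, x2, dout)
    c, back_cos = ChainRulesCore.rrule(node_cos, x2)       # child node
    out, back_mul = ChainRulesCore.rrule(node_mul, x1, c)  # root node
    _, dx1, dc = back_mul(dout)   # cotangents for the root's two children
    _, dx2 = back_cos(dc)         # propagate into the cos subtree
    return out, dx1, dx2
end
```

For example, tree_gradient([1.0, 2.0], [0.0, 1.0], ones(2)) returns dx1 equal to cos.(x2) and dx2 equal to -x1 .* sin.(x2). Once per-node rrules exist, this bookkeeping is roughly what a ChainRules-aware reverse-mode AD system would do automatically.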

MilesCranmer commented 1 year ago

Actually, Enzyme.jl seems to be working well for this! It requires much less maintenance and is very likely the faster option. See https://github.com/EnzymeAD/Enzyme.jl/issues/810 for updates.