SymbolicML / DynamicExpressions.jl

Ridiculously fast symbolic expressions
https://symbolicml.org/DynamicExpressions.jl/dev
Apache License 2.0

Add ChainRules support #71

Closed · MilesCranmer closed this 1 month ago

MilesCranmer commented 1 month ago

Implements ChainRulesCore.rrule for eval_tree_array, for both the tree and the X arguments.

For the tree argument I had to implement something custom, because ChainRulesCore.Tangent doesn't support recursive types. To get around this, I define

struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: AbstractTangent
    tree::N      # the primal tree this tangent belongs to
    gradient::A  # gradient with respect to the tree's constants, depth-first
end

where gradient is a vector of gradients with respect to the constants of the tree, in the usual depth-first order. NodeTangent implements some of the AbstractTangent interface (as much as makes sense).

However, this probably requires some care in downstream uses, because it is not an array.
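
As a quick illustration of how this could be consumed (a hedged sketch rather than code from this PR; the toy expression and loss are made up, and I'm assuming the rrule composes with Zygote as shown):

using DynamicExpressions, Zygote

operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
x1 = Node{Float64}(; feature=1)
tree = x1 * 3.2 + cos(x1 - 0.5)
X = randn(Float64, 1, 100)

# Differentiate a scalar loss through eval_tree_array. The gradient with
# respect to `tree` would arrive as a NodeTangent; its `gradient` field
# holds the derivatives for the constants (here 3.2 and 0.5), depth-first.
loss(t) = sum(abs2, first(eval_tree_array(t, X, operators)))
(dtree,) = Zygote.gradient(loss, tree)
dtree.gradient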

@avik-pal perhaps this is useful for the Lux.jl extension? (Would love to hear what you think of this PR, btw, given your experience in this area)


TODO:

coveralls commented 1 month ago

Pull Request Test Coverage Report for Build 8870522033

Totals:
- Change from base Build 8870212630: 0.08%
- Covered Lines: 1637
- Relevant Lines: 1727

💛 - Coveralls
github-actions[bot] commented 1 month ago

Benchmark Results

| Benchmark | master | 749385a1183b4e... | master/749385a1183b4e... |
|---|---|---|---|
| eval/ComplexF32/evaluation | 7.45 ± 0.48 ms | 7.44 ± 0.46 ms | 1 |
| eval/ComplexF64/evaluation | 9.78 ± 0.66 ms | 9.67 ± 0.72 ms | 1.01 |
| eval/Float32/derivative | 11.1 ± 1.9 ms | 10.9 ± 1.9 ms | 1.01 |
| eval/Float32/derivative_turbo | 11 ± 1.9 ms | 11 ± 2 ms | 1 |
| eval/Float32/evaluation | 2.74 ± 0.22 ms | 2.72 ± 0.23 ms | 1.01 |
| eval/Float32/evaluation_bumper | 0.579 ± 0.013 ms | 0.558 ± 0.013 ms | 1.04 |
| eval/Float32/evaluation_turbo | 0.711 ± 0.034 ms | 0.716 ± 0.03 ms | 0.993 |
| eval/Float32/evaluation_turbo_bumper | 0.577 ± 0.013 ms | 0.557 ± 0.012 ms | 1.03 |
| eval/Float64/derivative | 14.4 ± 0.83 ms | 14.4 ± 0.75 ms | 1 |
| eval/Float64/derivative_turbo | 14.7 ± 1 ms | 14.8 ± 1 ms | 0.998 |
| eval/Float64/evaluation | 2.93 ± 0.23 ms | 2.91 ± 0.25 ms | 1.01 |
| eval/Float64/evaluation_bumper | 1.29 ± 0.044 ms | 1.2 ± 0.045 ms | 1.07 |
| eval/Float64/evaluation_turbo | 1.21 ± 0.068 ms | 1.21 ± 0.062 ms | 1 |
| eval/Float64/evaluation_turbo_bumper | 1.28 ± 0.042 ms | 1.2 ± 0.041 ms | 1.07 |
| utils/combine_operators/break_sharing | 0.0408 ± 0.0013 ms | 0.0416 ± 0.0013 ms | 0.981 |
| utils/convert/break_sharing | 28.5 ± 1 μs | 29 ± 1.2 μs | 0.985 |
| utils/convert/preserve_sharing | 0.131 ± 0.0038 ms | 0.131 ± 0.0035 ms | 0.996 |
| utils/copy/break_sharing | 29.7 ± 1 μs | 29.3 ± 1.2 μs | 1.01 |
| utils/copy/preserve_sharing | 0.131 ± 0.0035 ms | 0.129 ± 0.0035 ms | 1.02 |
| utils/count_constants/break_sharing | 10.6 ± 0.16 μs | 10.9 ± 0.21 μs | 0.978 |
| utils/count_constants/preserve_sharing | 0.112 ± 0.0028 ms | 0.113 ± 0.0028 ms | 0.994 |
| utils/count_depth/break_sharing | 17.3 ± 0.41 μs | 18.5 ± 0.44 μs | 0.936 |
| utils/count_nodes/break_sharing | 9.95 ± 0.18 μs | 10.1 ± 0.19 μs | 0.987 |
| utils/count_nodes/preserve_sharing | 0.116 ± 0.0032 ms | 0.116 ± 0.0027 ms | 1.01 |
| utils/get_set_constants!/break_sharing | 0.0526 ± 0.00086 ms | 0.0537 ± 0.00093 ms | 0.98 |
| utils/get_set_constants!/preserve_sharing | 0.329 ± 0.0089 ms | 0.328 ± 0.0071 ms | 1 |
| utils/has_constants/break_sharing | 4.37 ± 0.22 μs | 4.6 ± 0.22 μs | 0.949 |
| utils/has_operators/break_sharing | 1.78 ± 0.027 μs | 1.93 ± 0.024 μs | 0.923 |
| utils/hash/break_sharing | 29.9 ± 0.44 μs | 30.2 ± 0.48 μs | 0.99 |
| utils/hash/preserve_sharing | 0.134 ± 0.0033 ms | 0.134 ± 0.0032 ms | 0.999 |
| utils/index_constants/break_sharing | 27.9 ± 0.83 μs | 27.8 ± 0.76 μs | 1 |
| utils/index_constants/preserve_sharing | 0.131 ± 0.0047 ms | 0.132 ± 0.0051 ms | 0.989 |
| utils/is_constant/break_sharing | 4.78 ± 0.23 μs | 4.36 ± 0.23 μs | 1.1 |
| utils/simplify_tree/break_sharing | 0.176 ± 0.016 ms | 0.18 ± 0.015 ms | 0.982 |
| utils/simplify_tree/preserve_sharing | 0.321 ± 0.017 ms | 0.323 ± 0.018 ms | 0.994 |
| utils/string_tree/break_sharing | 0.531 ± 0.018 ms | 0.527 ± 0.016 ms | 1.01 |
| utils/string_tree/preserve_sharing | 0.686 ± 0.027 ms | 0.672 ± 0.023 ms | 1.02 |
| time_to_load | 0.2 ± 0.0098 s | 0.211 ± 0.0023 s | 0.949 |
avik-pal commented 1 month ago

> @avik-pal perhaps this is useful for the Lux.jl extension? (Would love to hear what you think of this PR, btw, given your experience in this area)

This looks great. I think I will be able to remove some of the custom handling I had in Lux for this.

MilesCranmer commented 1 month ago

Fantastic. Thanks for looking!

avik-pal commented 1 month ago

Do you plan to capture ForwardDiff calls as well? I was unsure how to capture them at the Node constants level; for Lux, I handled them at the parameters level: https://github.com/LuxDL/Lux.jl/blob/main/ext/LuxDynamicExpressionsForwardDiffExt.jl#L8-L52

MilesCranmer commented 1 month ago

> Do you plan to capture ForwardDiff calls as well?

I would be very happy to have ForwardDiff support for tree constants. For my own use cases it's lower on the priority list, so I'm not sure when I'll get to it. The rrule is the priority for me so far, as I want to have some Zygote-based AD optimization in SymbolicRegression.jl (right now it's still finite difference-based, which surprisingly hasn't been so bad given the problems are low-dimensional, but it can get a bit slow for very complex expressions).

> I was unsure how to capture them at the Node constants level; for Lux, I handled them at the parameters level: https://github.com/LuxDL/Lux.jl/blob/main/ext/LuxDynamicExpressionsForwardDiffExt.jl#L8-L52

Nice! I'm not sure how to translate this, but let me know if you'd be open to moving it over here. Not sure how much work it would be, though.

> capture them at the Node constants level

In the Optim.optimize extension, what I do is store a vector of Refs to the constant nodes, and just update them by dereferencing. (Not sure if this is what you were asking.)

https://github.com/SymbolicML/DynamicExpressions.jl/blob/27b619951d6faee18628fe7da427161dee755a7f/ext/DynamicExpressionsOptimExt.jl#L91-L94

Then I can update all the parameters by

https://github.com/SymbolicML/DynamicExpressions.jl/blob/27b619951d6faee18628fe7da427161dee755a7f/ext/DynamicExpressionsOptimExt.jl#L114-L117

The nice part about this is that it also works for GraphNode, where you have multiple parents pointing to the same child: the filter_map will only return a single Ref to the shared child node, so you don't end up optimizing the same parameter from two places.
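
For concreteness, here is a rough sketch of that Ref-based pattern (hypothetical helper names, not the extension's actual code; I'm assuming filter_map's (filter_fn, map_fn, tree, result_type) signature):

using DynamicExpressions

# Collect one Ref per constant node, depth-first. For GraphNode, shared
# children are visited only once, so each constant yields exactly one Ref.
function extract_constant_refs(tree::N) where {T,N<:AbstractExpressionNode{T}}
    return filter_map(
        node -> node.degree == 0 && node.constant,
        Ref,
        tree,
        Base.RefValue{N},
    )
end

# Write a flat parameter vector back into the tree by dereferencing.
function set_constants_from_refs!(refs, x::AbstractVector)
    for (ref, xi) in zip(refs, x)
        ref[].val = xi
    end
    return nothing
end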

avik-pal commented 1 month ago

Yes, I am definitely open to moving them here. What I meant by capturing them is how to define the dispatch. For example, in Lux, since I keep the parameters extracted in a vector, it is simple enough to write ::AbstractVector{<:Dual}. I am not sure how to detect ForwardDiff Duals "nicely" when they are part of the Nodes.
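
To make the asymmetry concrete (a sketch with made-up names, not Lux's actual code):

using ForwardDiff: Dual

# At the vector level, the eltype carries the Dual information, so
# dispatch is trivial:
my_eval(ps::AbstractVector{<:Real}, x) = sum(abs2, ps) * x  # generic path
my_eval(ps::AbstractVector{<:Dual}, x) = sum(abs2, ps) * x  # Dual-aware path (would extract partials here)

# At the Node level, the analogue would be dispatching on something like
# tree::AbstractExpressionNode{<:Dual}, which only applies once the tree's
# constants have themselves been converted to Dual numbers.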

I think it is possible to do it here, because that code won't natively work with ForwardDiff: https://github.com/SymbolicML/DynamicExpressions.jl/blob/27b619951d6faee18628fe7da427161dee755a7f/ext/DynamicExpressionsOptimExt.jl#L38-L45

It is also possible that ForwardDiff might be efficient enough without this special handling, given that you mention FiniteDifferences is already fast.

MilesCranmer commented 1 month ago

I see, thanks. Seems a bit trickier. Will think more...