MilesCranmer closed this 1 month ago
| Totals | |
|---|---|
| Change from base Build 8870212630: | 0.08% |
| Covered Lines: | 1637 |
| Relevant Lines: | 1727 |
| Benchmark | master | 749385a1183b4e... | master / 749385a1183b4e... |
|---|---|---|---|
| eval/ComplexF32/evaluation | 7.45 ± 0.48 ms | 7.44 ± 0.46 ms | 1 |
| eval/ComplexF64/evaluation | 9.78 ± 0.66 ms | 9.67 ± 0.72 ms | 1.01 |
| eval/Float32/derivative | 11.1 ± 1.9 ms | 10.9 ± 1.9 ms | 1.01 |
| eval/Float32/derivative_turbo | 11 ± 1.9 ms | 11 ± 2 ms | 1 |
| eval/Float32/evaluation | 2.74 ± 0.22 ms | 2.72 ± 0.23 ms | 1.01 |
| eval/Float32/evaluation_bumper | 0.579 ± 0.013 ms | 0.558 ± 0.013 ms | 1.04 |
| eval/Float32/evaluation_turbo | 0.711 ± 0.034 ms | 0.716 ± 0.03 ms | 0.993 |
| eval/Float32/evaluation_turbo_bumper | 0.577 ± 0.013 ms | 0.557 ± 0.012 ms | 1.03 |
| eval/Float64/derivative | 14.4 ± 0.83 ms | 14.4 ± 0.75 ms | 1 |
| eval/Float64/derivative_turbo | 14.7 ± 1 ms | 14.8 ± 1 ms | 0.998 |
| eval/Float64/evaluation | 2.93 ± 0.23 ms | 2.91 ± 0.25 ms | 1.01 |
| eval/Float64/evaluation_bumper | 1.29 ± 0.044 ms | 1.2 ± 0.045 ms | 1.07 |
| eval/Float64/evaluation_turbo | 1.21 ± 0.068 ms | 1.21 ± 0.062 ms | 1 |
| eval/Float64/evaluation_turbo_bumper | 1.28 ± 0.042 ms | 1.2 ± 0.041 ms | 1.07 |
| utils/combine_operators/break_sharing | 0.0408 ± 0.0013 ms | 0.0416 ± 0.0013 ms | 0.981 |
| utils/convert/break_sharing | 28.5 ± 1 μs | 29 ± 1.2 μs | 0.985 |
| utils/convert/preserve_sharing | 0.131 ± 0.0038 ms | 0.131 ± 0.0035 ms | 0.996 |
| utils/copy/break_sharing | 29.7 ± 1 μs | 29.3 ± 1.2 μs | 1.01 |
| utils/copy/preserve_sharing | 0.131 ± 0.0035 ms | 0.129 ± 0.0035 ms | 1.02 |
| utils/count_constants/break_sharing | 10.6 ± 0.16 μs | 10.9 ± 0.21 μs | 0.978 |
| utils/count_constants/preserve_sharing | 0.112 ± 0.0028 ms | 0.113 ± 0.0028 ms | 0.994 |
| utils/count_depth/break_sharing | 17.3 ± 0.41 μs | 18.5 ± 0.44 μs | 0.936 |
| utils/count_nodes/break_sharing | 9.95 ± 0.18 μs | 10.1 ± 0.19 μs | 0.987 |
| utils/count_nodes/preserve_sharing | 0.116 ± 0.0032 ms | 0.116 ± 0.0027 ms | 1.01 |
| utils/get_set_constants!/break_sharing | 0.0526 ± 0.00086 ms | 0.0537 ± 0.00093 ms | 0.98 |
| utils/get_set_constants!/preserve_sharing | 0.329 ± 0.0089 ms | 0.328 ± 0.0071 ms | 1 |
| utils/has_constants/break_sharing | 4.37 ± 0.22 μs | 4.6 ± 0.22 μs | 0.949 |
| utils/has_operators/break_sharing | 1.78 ± 0.027 μs | 1.93 ± 0.024 μs | 0.923 |
| utils/hash/break_sharing | 29.9 ± 0.44 μs | 30.2 ± 0.48 μs | 0.99 |
| utils/hash/preserve_sharing | 0.134 ± 0.0033 ms | 0.134 ± 0.0032 ms | 0.999 |
| utils/index_constants/break_sharing | 27.9 ± 0.83 μs | 27.8 ± 0.76 μs | 1 |
| utils/index_constants/preserve_sharing | 0.131 ± 0.0047 ms | 0.132 ± 0.0051 ms | 0.989 |
| utils/is_constant/break_sharing | 4.78 ± 0.23 μs | 4.36 ± 0.23 μs | 1.1 |
| utils/simplify_tree/break_sharing | 0.176 ± 0.016 ms | 0.18 ± 0.015 ms | 0.982 |
| utils/simplify_tree/preserve_sharing | 0.321 ± 0.017 ms | 0.323 ± 0.018 ms | 0.994 |
| utils/string_tree/break_sharing | 0.531 ± 0.018 ms | 0.527 ± 0.016 ms | 1.01 |
| utils/string_tree/preserve_sharing | 0.686 ± 0.027 ms | 0.672 ± 0.023 ms | 1.02 |
| time_to_load | 0.2 ± 0.0098 s | 0.211 ± 0.0023 s | 0.949 |
@avik-pal perhaps this is useful for the Lux.jl extension? (Would love to hear what you think of this PR, btw, given your experience in this area)
This looks great! I think I will be able to remove some of the custom handling I had in Lux for this.
Fantastic. Thanks for looking!
Do you plan to capture ForwardDiff calls as well? I was unsure how to capture them at the `Node` constants level; for Lux I handled them at the parameters level: https://github.com/LuxDL/Lux.jl/blob/main/ext/LuxDynamicExpressionsForwardDiffExt.jl#L8-L52
> Do you plan to capture ForwardDiff calls as well?
I would be very happy to have ForwardDiff support for tree constants. For my own use-cases it's lower on the priority list, so I'm not sure when I'll get to it. The `rrule` is so far a priority for me, as I want to have some Zygote-based AD optimization in SymbolicRegression.jl (right now it's still finite-difference-based – which surprisingly hasn't been so bad given it's low-dimensional, but can get a bit slow for very complex expressions).
> I was unsure how to capture them at the Node constants level, for Lux I handled them at the parameters level https://github.com/LuxDL/Lux.jl/blob/main/ext/LuxDynamicExpressionsForwardDiffExt.jl#L8-L52
Nice! I'm not sure how to translate this but let me know if you'd be open to moving it over here. Not sure how much work it would be though.
> capture them at the Node constants level
In the `Optim.optimize` extension, what I will do is store a vector of `Ref`s to the constant nodes, and just update them via dereferencing. (Not sure if this is what you were asking.) Then I can update all the parameters by dereferencing each `Ref` and assigning the new value. The nice part about this is that it also works for `GraphNode`, where you have multiple parents pointing to the same child – the `filter_map` will only return a single `Ref` to the child node, so you don't end up optimizing the same parameter from two elements.
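A minimal sketch of that idea, using a hypothetical `ConstNode` stand-in rather than the package's actual `Node`/`GraphNode` types:

```julia
# Hypothetical stand-in for a constant-holding node (illustration only).
mutable struct ConstNode
    val::Float64
end

# Two parents pointing at the same child, as can happen in a GraphNode:
shared = ConstNode(1.5)
constant_nodes = [shared, ConstNode(2.0), shared]

# Deduplicate by identity so the shared child yields only a single Ref,
# analogous to what filter_map does for graph-shared nodes:
refs = [Ref(n) for n in unique(objectid, constant_nodes)]

# Update all parameters from a flat vector by dereferencing each Ref:
x = [10.0, 20.0]
for (r, xi) in zip(refs, x)
    r[].val = xi
end
```

Because `refs` is deduplicated by identity, `shared` is updated exactly once (from `x[1]`) even though it appears twice in `constant_nodes`.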
Yes, I am definitely open to moving them here. What I meant by capturing them is how to define the dispatch. For example, in Lux, since I keep the parameters extracted in a vector, it is simple enough to write `::AbstractVector{<:Dual}`. I am not sure how to detect ForwardDiff `Dual`s "nicely" when they are part of the `Node`s.
It is possible to do it here https://github.com/SymbolicML/DynamicExpressions.jl/blob/27b619951d6faee18628fe7da427161dee755a7f/ext/DynamicExpressionsOptimExt.jl#L38-L45 I think, because that code won't natively work with ForwardDiff.
It is also possible that ForwardDiff might be efficient enough without this special handling, given that you mention FiniteDifferences is already fast.
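For what it's worth, one conceivable way to detect dual-valued constants is to dispatch on the node's element-type parameter. This is only a sketch with stand-in types (`FakeDual` and `MyNode` are illustrative; they are not ForwardDiff's `Dual` or the package's `Node`):

```julia
# Stand-in for ForwardDiff.Dual, for illustration only.
struct FakeDual{T}
    value::T
    partial::T
end

# Stand-in for a parametric Node{T}-style type.
struct MyNode{T}
    val::T
end

# Dispatch on the element type buried inside the node's type parameter:
handle(n::MyNode{<:FakeDual}) = :dual_path
handle(n::MyNode{<:Real})     = :plain_path
```

Whether this pattern fits the actual `Node` type hierarchy is an open question; it only works if the dual-ness is visible in the type parameter.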
I see, thanks. Seems a bit trickier. Will think more...
Implements `ChainRulesCore.rrule` for `eval_tree_array`, for the `tree` and the `X` arguments.

For the `tree` argument I had to implement something custom, because `ChainRulesCore.Tangent` doesn't support recursive types. To get around this I implement a `NodeTangent` type, where `gradient` is a vector gradient of the constants in the tree in the usual depth-first order. It has some of the `AbstractTangent` interface implemented (as much as makes sense). However, this probably requires some care in downstream uses because it's not an array.
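As a rough illustration of the flat-gradient idea (the names here are illustrative, not the package's actual `NodeTangent` definition, which carries more structure):

```julia
# Illustrative flat-gradient tangent type: holds the gradient of the tree's
# constants in depth-first order, with just enough tangent arithmetic
# (addition and scalar scaling) to compose in a reverse-mode pass.
struct FlatTangent{T}
    gradient::Vector{T}
end

Base.:+(a::FlatTangent, b::FlatTangent) = FlatTangent(a.gradient .+ b.gradient)
Base.:*(s::Number, t::FlatTangent) = FlatTangent(s .* t.gradient)

# Tangents from two upstream pullbacks accumulate elementwise:
t = FlatTangent([1.0, 2.0]) + FlatTangent([3.0, 4.0])
```

Since such a type is not an `AbstractArray`, downstream code that expects array-shaped gradients (e.g. generic optimizers) has to unwrap the `gradient` field explicitly, which is the "care in downstream uses" mentioned above.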
TODO:

- `Optim.optimize` extension to use Zygote AD with this interface. Or at least be compatible with user-passed gradients that return a `NodeTangent`.