SymbolicML / DynamicExpressions.jl

Ridiculously fast symbolic expressions
https://symbolicml.org/DynamicExpressions.jl/dev
Apache License 2.0

Overload `Optim.optimize` for `::Node` #30

Closed MilesCranmer closed 8 months ago

MilesCranmer commented 1 year ago

This depends on #27.

This implements a memory-efficient method of Optim.optimize for a ::Node passed as the initial value: the constants in the tree are the parameters being optimized. For example:

using Optim
using DynamicExpressions

operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])

# Problem:
x1, x2, x3 = ntuple(i->Node(;feature=i), 3)
tree = cos(x1 * 3.2 - 5.8) * x2 - 0.8*x3*x3 - 0.2
truth(x) = cos(x[1] * 3.1 - 5.9) * x[2] - 2*x[3]^2 - 0.15

# Dataset:
X = randn(3, 32)
y = [truth(x) for x in eachcol(X)]

# Optimize!
optimize(tree -> sum(abs2, tree(X, operators) .- y), tree)

This gives us:

julia> tree
(((cos((x1 * 3.0999948526250973) - 5.900012476309413) * x2) - ((1.9999972274990503 * x3) * x3)) - 0.14998888911815988)

and is quite fast:

julia> @btime optimize(t -> sum(abs2, t(X) .- y), copy_node(tree))
  537.312 μs (4197 allocations: 357.98 KiB)
 * Status: success

 * Candidate solution
    Final objective value:     6.346103e-09

 * Found with
    Algorithm:     Nelder-Mead

 * Convergence measures
    √(Σ(yᵢ-ȳ)²)/n ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    150
    f(x) calls:    263

(Recall that this tree's structure is completely variable: it is a runtime data structure, not a compiled function.)
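To make the idea concrete, here is a minimal, self-contained sketch of optimizing the constants of an expression tree by mutating its constant nodes and re-evaluating a loss. Everything in it is an assumption for illustration: `ToyNode`, `toy_const`, `toy_feature`, `toy_op`, `evaluate`, `constant_nodes`, and `optimize_constants!` are hypothetical names, not the `Node` type or the code in this PR, and the simple compass (coordinate) search stands in for Nelder-Mead so that the sketch needs no packages.

```julia
# Toy expression tree. NOT the DynamicExpressions.jl `Node` type.
mutable struct ToyNode
    kind::Symbol               # :const, :feature, or :op
    val::Float64               # constant value (used when kind == :const)
    feature::Int               # feature index (used when kind == :feature)
    op::Function               # operator (used when kind == :op)
    children::Vector{ToyNode}
end

toy_const(v) = ToyNode(:const, v, 0, identity, ToyNode[])
toy_feature(i) = ToyNode(:feature, 0.0, i, identity, ToyNode[])
toy_op(op, args...) = ToyNode(:op, 0.0, 0, op, ToyNode[args...])

# Recursively evaluate the tree on a single sample `x`.
function evaluate(n::ToyNode, x::AbstractVector)
    n.kind === :const && return n.val
    n.kind === :feature && return x[n.feature]
    return n.op((evaluate(c, x) for c in n.children)...)
end

# Collect the constant nodes in depth-first order.
function constant_nodes!(out::Vector{ToyNode}, n::ToyNode)
    n.kind === :const && push!(out, n)
    foreach(c -> constant_nodes!(out, c), n.children)
    return out
end
constant_nodes(n::ToyNode) = constant_nodes!(ToyNode[], n)

# Compass search over the constants: try +/- step on each constant,
# keep a move only if the loss decreases, halve the step otherwise.
function optimize_constants!(loss, tree::ToyNode; step=1.0, tol=1e-6, maxiter=10_000)
    nodes = constant_nodes(tree)
    best = loss(tree)
    for _ in 1:maxiter
        step > tol || break
        improved = false
        for n in nodes, delta in (step, -step)
            old = n.val
            n.val = old + delta
            trial = loss(tree)
            if trial < best
                best = trial
                improved = true
            else
                n.val = old   # revert the rejected move
            end
        end
        improved || (step /= 2)
    end
    return best
end

# Toy problem: fit (x1 * c1) + c2 to data generated by 3*x1 + 1.
X = [[x] for x in range(-1, 1; length=11)]
y = [3.0 * x[1] + 1.0 for x in X]
tree = toy_op(+, toy_op(*, toy_feature(1), toy_const(2.3)), toy_const(0.4))
loss(t) = sum(abs2, [evaluate(t, x) for x in X] .- y)

final_loss = optimize_constants!(loss, tree)
fitted = [n.val for n in constant_nodes(tree)]
```

After the call, `fitted` should hold values close to 3.0 and 1.0. The point of the sketch is only that the optimizer sees the tree as a flat list of constants, so the same loop works for a tree of any shape.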

@AlCap23 @ChrisRackauckas maybe this is useful for your stuff too?

github-actions[bot] commented 1 year ago

Pull Request Test Coverage Report for Build 4915532924


| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| src/base.jl | 78 | 90 | 86.67% |
| src/ConstantOptimization.jl | 0 | 28 | 0.0% |

| Files with Coverage Reduction | New Missed Lines | % |
|---|---|---|
| src/EvaluateEquation.jl | 1 | 99.64% |
| src/Utils.jl | 1 | 71.11% |
| src/InterfaceSymbolicUtils.jl | 2 | 66.67% |
| src/precompile.jl | 2 | 80.8% |
| src/SimplifyEquation.jl | 2 | 86.42% |

Totals
Change from base Build 4824867612: -3.4%
Covered Lines: 963
Relevant Lines: 1128

💛 - Coveralls

github-actions[bot] commented 1 year ago

Benchmark Results

| Benchmark | master | 80c370fe8c11ba... | t[master]/t[80c370fe8c11ba...] |
|---|---|---|---|
| eval/ComplexF32/evaluation | 7.27 ± 0.45 ms | 7.22 ± 0.53 ms | 1.01 |
| eval/ComplexF64/evaluation | 9.48 ± 0.7 ms | 9.47 ± 0.69 ms | 1 |
| eval/Float32/derivative | 10.6 ± 1.3 ms | 10.5 ± 1.3 ms | 1.01 |
| eval/Float32/derivative_turbo | 10.6 ± 1.2 ms | 10.6 ± 1.3 ms | 0.996 |
| eval/Float32/evaluation | 2.54 ± 0.23 ms | 2.57 ± 0.23 ms | 0.986 |
| eval/Float32/evaluation_turbo | 0.541 ± 0.022 ms | 0.54 ± 0.022 ms | 1 |
| eval/Float64/derivative | 13.3 ± 0.62 ms | 13.3 ± 0.6 ms | 0.994 |
| eval/Float64/derivative_turbo | 13.3 ± 0.6 ms | 13.4 ± 0.54 ms | 0.997 |
| eval/Float64/evaluation | 2.71 ± 0.25 ms | 2.72 ± 0.23 ms | 0.997 |
| eval/Float64/evaluation_turbo | 1.02 ± 0.059 ms | 1.02 ± 0.059 ms | 0.996 |
| utils/combine_operators/break_sharing | 0.0493 ± 0.0031 ms | 0.0498 ± 0.0029 ms | 0.991 |
| utils/convert/break_sharing | 27.8 ± 0.92 μs | 27.6 ± 0.92 μs | 1.01 |
| utils/convert/preserve_sharing | 0.126 ± 0.0024 ms | 0.125 ± 0.0021 ms | 1.01 |
| utils/copy/break_sharing | 28.6 ± 0.86 μs | 28.1 ± 0.95 μs | 1.02 |
| utils/copy/preserve_sharing | 0.126 ± 0.0023 ms | 0.125 ± 0.0022 ms | 1.01 |
| utils/count_constants/break_sharing | 10.3 ± 0.16 μs | 10.5 ± 0.14 μs | 0.985 |
| utils/count_constants/preserve_sharing | 0.111 ± 0.0021 ms | 0.111 ± 0.0019 ms | 0.998 |
| utils/count_depth/break_sharing | 17.2 ± 0.37 μs | 12.7 ± 0.2 μs | 1.35 |
| utils/count_nodes/break_sharing | 10.1 ± 0.15 μs | 10.1 ± 0.15 μs | 0.998 |
| utils/count_nodes/preserve_sharing | 0.113 ± 0.002 ms | 0.114 ± 0.0019 ms | 0.994 |
| utils/get_set_constants!/break_sharing | 0.0528 ± 0.00087 ms | 0.053 ± 0.00081 ms | 0.997 |
| utils/get_set_constants!/preserve_sharing | 0.316 ± 0.005 ms | 0.316 ± 0.0048 ms | 0.998 |
| utils/has_constants/break_sharing | 4.37 ± 0.23 μs | 4.27 ± 0.22 μs | 1.02 |
| utils/has_operators/break_sharing | 1.77 ± 0.02 μs | 1.76 ± 0.02 μs | 1.01 |
| utils/hash/break_sharing | 29.8 ± 0.45 μs | 29.9 ± 0.44 μs | 0.996 |
| utils/hash/preserve_sharing | 0.13 ± 0.0023 ms | 0.131 ± 0.0042 ms | 0.989 |
| utils/index_constants/break_sharing | 27.4 ± 0.68 μs | 27.2 ± 0.58 μs | 1.01 |
| utils/index_constants/preserve_sharing | 0.126 ± 0.0022 ms | 0.125 ± 0.0022 ms | 1.01 |
| utils/is_constant/break_sharing | 4.75 ± 0.22 μs | 4.76 ± 0.22 μs | 0.998 |
| utils/simplify_tree/break_sharing | 0.17 ± 0.015 ms | 0.175 ± 0.015 ms | 0.971 |
| utils/simplify_tree/preserve_sharing | 0.304 ± 0.018 ms | 0.293 ± 0.017 ms | 1.04 |
| utils/string_tree/break_sharing | 0.491 ± 0.0098 ms | 0.491 ± 0.0091 ms | 1 |
| utils/string_tree/preserve_sharing | 0.631 ± 0.011 ms | 0.627 ± 0.012 ms | 1.01 |
| time_to_load | 0.643 ± 0.0032 s | 0.645 ± 0.0093 s | 0.996 |

ChrisRackauckas commented 1 year ago

I don't think I've needed this but I can see how it could be useful.

MilesCranmer commented 1 year ago

The Enzyme interface would make this nicer: https://github.com/MilesCranmer/SymbolicRegression.jl/pull/254