SymbolicML / DynamicExpressions.jl

Ridiculously fast symbolic expressions
https://symbolicml.org/DynamicExpressions.jl/dev
Apache License 2.0

Enzyme compatibility #52

Closed MilesCranmer closed 5 months ago

MilesCranmer commented 10 months ago

This implements several changes to get Enzyme compatibility working (in support of https://github.com/MilesCranmer/SymbolicRegression.jl/pull/254):

  1. Makes OperatorEnum based on tuples of functions rather than vectors. Thus, it once again is specialized to the operators, rather than a generic struct.
  2. Removes all remaining type inference from the evaluation kernels, by using a generated if statement over the operators, rather than indexing a vector of functions.
  3. Makes the internal evaluation kernels return a struct storing the result and a flag. Before, each evaluation kernel returned a tuple. This hurt type inference for some reason (maybe because (a, b) = (1, 2, 3) is valid Julia code?)
  4. Ensures that the constant tree evaluation returns mutable storage, as otherwise that branch seemed to mess with Enzyme.jl's gradients.
  5. Allows Val(false) to be passed for turbo in the evaluation code (rather than just false), so that the compiler can completely remove the LoopVectorization branches inside an Enzyme gradient call.
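
A minimal sketch of how items 1, 2, 3, and 5 fit together, using only Base Julia. Every name here (`SketchOperatorEnum`, `SketchResult`, `_apply_op`, `apply_unary`, `eval_unary`) is a hypothetical stand-in for illustration and is not the package's actual API; the real kernels are more involved.

```julia
# Hypothetical stand-in: operators stored as tuples, so each operator's
# type is part of the struct's type (item 1).
struct SketchOperatorEnum{B<:Tuple,U<:Tuple}
    binops::B
    unaops::U
end

# Hypothetical stand-in: a concrete struct holding the result and a
# completion flag, instead of a tuple (item 3).
struct SketchResult{A<:AbstractVector}
    x::A
    ok::Bool
end

# Function barrier: specialized on the concrete operator type F.
function _apply_op(op::F, x::AbstractVector) where {F}
    out = similar(x)
    for j in eachindex(x)
        v = op(x[j])
        isfinite(v) || return SketchResult(out, false)
        out[j] = v
    end
    return SketchResult(out, true)
end

# Generated if/else chain over the unary operators (item 2). Each branch
# indexes the tuple with a literal, so the operator type is known statically.
@generated function apply_unary(
    ops::SketchOperatorEnum{B,U}, op_idx::Integer, x::AbstractVector
) where {B,U}
    body = :(error("operator index out of range"))
    for i in fieldcount(U):-1:1
        body = quote
            if op_idx == $i
                return _apply_op(ops.unaops[$i], x)
            else
                $body
            end
        end
    end
    return body
end

# `turbo` is a type parameter, so the unused branch can be removed at
# compile time (item 5). The vectorized path is omitted in this sketch.
function eval_unary(ops, op_idx, x, ::Val{turbo}=Val(false)) where {turbo}
    if turbo
        error("turbo path is not implemented in this sketch")
    else
        return apply_unary(ops, op_idx, x)
    end
end

ops = SketchOperatorEnum((+, *), (sin, cos))
res = eval_unary(ops, 2, [0.0, 1.0], Val(false))  # applies cos elementwise
```

Here `res.ok` is `true` and `res.x` holds `cos.([0.0, 1.0])`. With a `Vector{Function}` instead of a tuple, the call `ops.unaops[i](x)` would have to be dispatched at run time, which is the kind of dynamic behavior the list above aims to remove.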

There are a few other unrelated changes which made sense to implement simultaneously:

  1. Add tests for type stability
  2. Deprecate enable_autodiff. Instead, we just use Zygote within each gradient kernel – it should compile in the differential operator anyways, so no need to store them (especially because we store operators in a tuple now).
  3. Fix some type inference issues with eval_grad_tree_array.
  4. Deprecate turbo for eval_grad_tree_array, as it does not seem to improve performance through Zygote anyway.
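
As a rough illustration of what a type-stability test can look like, here is a sketch using `@inferred` from the `Test` standard library. The two kernels are hypothetical toy functions, not the package's evaluation kernels; they only contrast tuple storage with vector storage of operators.

```julia
using Test

# Hypothetical toy kernels, for illustration only.
# Tuple storage: the type of each operator is known to the compiler.
stable_kernel(ops::Tuple, x::Float64) = ops[1](x)
# Vector storage: the element type is the abstract `Function`.
unstable_kernel(ops::Vector{Function}, x::Float64) = ops[1](x)

# Passes: the inferred return type matches the actual one (Float64).
y = @inferred stable_kernel((sin, cos), 1.0)

# `@inferred` throws here because the return type is inferred as `Any`.
@test_throws ErrorException @inferred unstable_kernel(Function[sin, cos], 1.0)
```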


coveralls commented 10 months ago

Pull Request Test Coverage Report for Build 7517350397


Totals Coverage Status
Change from base Build 7515102605: -0.06%
Covered Lines: 1368
Relevant Lines: 1452

github-actions[bot] commented 10 months ago

Benchmark Results

| Benchmark | master | ecb3574c93d373... | t[master]/t[ecb3574c93d373...] |
|---|---|---|---|
| eval/ComplexF32/evaluation | 7.39 ± 0.49 ms | 7.26 ± 0.47 ms | 1.02 |
| eval/ComplexF64/evaluation | 9.55 ± 0.72 ms | 9.49 ± 0.7 ms | 1.01 |
| eval/Float32/derivative | 10.8 ± 1.5 ms | 10.8 ± 1.5 ms | 0.998 |
| eval/Float32/derivative_turbo | 12.1 ± 1.5 ms | 10.7 ± 1.4 ms | 1.13 |
| eval/Float32/evaluation | 2.67 ± 0.22 ms | 2.6 ± 0.22 ms | 1.03 |
| eval/Float32/evaluation_turbo | 0.619 ± 0.029 ms | 0.551 ± 0.027 ms | 1.12 |
| eval/Float64/derivative | 13.9 ± 0.71 ms | 13.7 ± 0.53 ms | 1.01 |
| eval/Float64/derivative_turbo | 14.4 ± 0.61 ms | 13.8 ± 0.6 ms | 1.04 |
| eval/Float64/evaluation | 2.85 ± 0.24 ms | 2.74 ± 0.24 ms | 1.04 |
| eval/Float64/evaluation_turbo | 1.12 ± 0.061 ms | 1.03 ± 0.058 ms | 1.09 |
| utils/combine_operators/break_sharing | 0.0398 ± 0.0024 ms | 0.0505 ± 0.0029 ms | 0.789 |
| utils/convert/break_sharing | 28.1 ± 0.62 μs | 27.8 ± 1 μs | 1.01 |
| utils/convert/preserve_sharing | 0.127 ± 0.0029 ms | 0.13 ± 0.0027 ms | 0.977 |
| utils/copy/break_sharing | 28.7 ± 0.66 μs | 28.5 ± 1 μs | 1.01 |
| utils/copy/preserve_sharing | 0.128 ± 0.0029 ms | 0.131 ± 0.0026 ms | 0.98 |
| utils/count_constants/break_sharing | 10.3 ± 0.16 μs | 10.6 ± 0.16 μs | 0.972 |
| utils/count_constants/preserve_sharing | 0.114 ± 0.0026 ms | 0.111 ± 0.0025 ms | 1.02 |
| utils/count_depth/break_sharing | 17 ± 0.41 μs | 17.3 ± 0.38 μs | 0.984 |
| utils/count_nodes/break_sharing | 9.83 ± 0.17 μs | 10.2 ± 0.15 μs | 0.967 |
| utils/count_nodes/preserve_sharing | 0.116 ± 0.0026 ms | 0.115 ± 0.0027 ms | 1.02 |
| utils/get_set_constants!/break_sharing | 0.0535 ± 0.00081 ms | 0.0532 ± 0.00073 ms | 1.01 |
| utils/get_set_constants!/preserve_sharing | 0.327 ± 0.0067 ms | 0.322 ± 0.0063 ms | 1.02 |
| utils/has_constants/break_sharing | 4.61 ± 0.21 μs | 4.34 ± 0.22 μs | 1.06 |
| utils/has_operators/break_sharing | 2.09 ± 0.018 μs | 1.77 ± 0.03 μs | 1.18 |
| utils/hash/break_sharing | 30.5 ± 0.48 μs | 30 ± 0.45 μs | 1.02 |
| utils/hash/preserve_sharing | 0.132 ± 0.0028 ms | 0.132 ± 0.0025 ms | 1 |
| utils/index_constants/break_sharing | 27.5 ± 0.67 μs | 27.7 ± 0.77 μs | 0.994 |
| utils/index_constants/preserve_sharing | 0.129 ± 0.0028 ms | 0.127 ± 0.0026 ms | 1.01 |
| utils/is_constant/break_sharing | 4.62 ± 0.23 μs | 4.77 ± 0.21 μs | 0.969 |
| utils/simplify_tree/break_sharing | 0.259 ± 0.021 ms | 0.174 ± 0.016 ms | 1.49 |
| utils/simplify_tree/preserve_sharing | 0.376 ± 0.022 ms | 0.296 ± 0.017 ms | 1.27 |
| utils/string_tree/break_sharing | 0.564 ± 0.015 ms | 0.501 ± 0.014 ms | 1.13 |
| utils/string_tree/preserve_sharing | 0.695 ± 0.02 ms | 0.639 ± 0.017 ms | 1.09 |
| time_to_load | 0.686 ± 0.0026 s | 0.66 ± 0.0064 s | 1.04 |

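
For reading the last column: it is the time on master divided by the time on the PR commit, so a value above 1 means the PR commit ran faster. A small sketch reproducing one entry from the rounded numbers reported for eval/Float32/derivative_turbo (the variable names are illustrative only):

```julia
# Median times in milliseconds, copied from the eval/Float32/derivative_turbo row.
t_master = 12.1
t_pr = 10.7

ratio = t_master / t_pr
println(round(ratio; digits=2))  # prints 1.13
```

Recomputing from the rounded times can differ in the last digit from the reported ratio, since the bot presumably works from unrounded measurements.
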
MilesCranmer commented 6 months ago

Removing from the v0.14 milestone, as there are still some remaining issues with compilation times.

wsmoses commented 5 months ago

@MilesCranmer FYI the fix for that issue is released in Enzyme 0.11.12 if you want to try rerunning this with that.

MilesCranmer commented 5 months ago

Thanks! Trying now.

MilesCranmer commented 5 months ago

Awesome. Seems to work now! What was the thing that fixed it?

Edit: it looks like there are issues on all systems for Julia 1.6 maybe, as well as Windows? https://github.com/SymbolicML/DynamicExpressions.jl/actions/runs/7394870782/job/20116976319?pr=52#step:6:382 https://github.com/SymbolicML/DynamicExpressions.jl/actions/runs/7394870782/job/20116975868#step:6:413 – I can report these if you want

```asm
; Function Attrs: willreturn
define internal fastcc void @preprocess_julia_dispatch_deg1_eval_20893({ {} addrspace(10)*, i8 }* noalias nocapture sret %0, [1 x {} addrspace(10)*]* noalias nocapture writeonly "enzyme_inactive" "enzymejl_returnRoots" %1, {} addrspace(10)* nonnull align 8 dereferenceable(40) %2, {} addrspace(10)* nonnull align 16 dereferenceable(40) %3, i8 zeroext "enzyme_inactive" %4) unnamed_addr #76 !dbg !8892 {
top:
  %5 = call {}*** @julia.ptls_states()
  %6 = alloca { {} addrspace(10)*, i8 }, align 8
  %7 = alloca [1 x {} addrspace(10)*], align 8
  %8 = alloca { {} addrspace(10)*, i8 }, align 8
  %9 = alloca [1 x {} addrspace(10)*], align 8
  %10 = alloca { {} addrspace(10)*, i8 }, align 8
  %11 = alloca [1 x {} addrspace(10)*], align 8
  %12 = alloca { {} addrspace(10)*, i8 }, align 8
  %13 = alloca [1 x {} addrspace(10)*], align 8
  %14 = alloca { {} addrspace(10)*, i8 }, align 8
  %15 = alloca [1 x {} addrspace(10)*], align 8
  %16 = alloca { {} addrspace(10)*, i8 }, align 8
  %17 = alloca [1 x {} addrspace(10)*], align 8
  %18 = alloca { {} addrspace(10)*, i8 }, align 8
  %19 = alloca [1 x {} addrspace(10)*], align 8
  %20 = alloca { {} addrspace(10)*, i8 }, align 8
  %.not179 = icmp eq i8 %159, 0, !dbg !8944
  br i1 %.not179, label %pass45, label %pass48, !dbg !8943

pass45:                                           ; preds = %pass43
  %160 = getelementptr inbounds i8, i8 addrspace(11)* %138, i64 20, !dbg !8947
  %161 = load i8, i8 addrspace(11)* %160, align 4, !dbg !8947, !tbaa !1339
  %.not180 = icmp eq i8 %161, 1, !dbg !8949
  br i1 %.not180, label %L925, label %L932, !dbg !8952

pass48:                                           ; preds = %pass43, %pass34, %pass30, %pass26
  call fastcc void @julia__eval_tree_array_20774({ {} addrspace(10)*, i8 }* noalias nocapture nonnull sret %32, [1 x {} addrspace(10)*]* noalias nocapture nonnull "enzymejl_returnRoots" %36, {} addrspace(10)* %42, {} addrspace(10)* %3) #77, !dbg !8953
  %162 = getelementptr inbounds { {} addrspace(10)*, i8 }, { {} addrspace(10)*, i8 }* %32, i64 0, i32 1, !dbg !8954
  %163 = load i8, i8* %162, align 8, !dbg !8955, !tbaa !1389, !range !1421
  %.not181 = icmp eq i8 %163, 1, !dbg !8909
  %164 = getelementptr inbounds { {} addrspace(10)*, i8 }, { {} addrspace(10)*, i8 }* %32, i64 0, i32 0, !dbg !8956
  br i1 %.not181, label %L955, label %L954, !dbg !8909

oob50:                                            ; preds = %L1106
  %165 = alloca i64, align 8, !dbg !8932
  store i64 %99, i64* %165, align 8, !dbg !8932, !noalias !8903
  %166 = addrspacecast {} addrspace(10)* %82 to {} addrspace(12)*, !dbg !8932
  call void @jl_bounds_error_ints({} addrspace(12)* %166, i64* nonnull %165, i64 1) #77, !dbg !8932
  unreachable, !dbg !8932

idxend51:                                         ; preds = %L1106
  %167 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %98, align 8, !dbg !8932, !tbaa !249, !alias.scope !8903, !nonnull !4
  %168 = getelementptr inbounds double, double addrspace(13)* %167, i64 %value_phi49190, !dbg !8932
  %169 = load double, double addrspace(13)* %168, align 8, !dbg !8932, !tbaa !251
  %170 = call double @julia_sin_20915(double %169) #77, !dbg !8926
  store double %170, double addrspace(13)* %168, align 8, !dbg !8957, !tbaa !251, !noalias !8903
  %exitcond197.not = icmp eq i64 %99, %92, !dbg !8959
  br i1 %exitcond197.not, label %L1126, label %L1106, !dbg !8928, !llvm.loop !8961
}

rep:   %35 = bitcast {} addrspace(10)* %34 to { {} addrspace(10)*, i8 } addrspace(10)*, !enzyme_caststack !4
prev:   %6 = alloca { {} addrspace(10)*, i8 }, align 8
inst:   %.sink = phi { {} addrspace(10)*, i8 }* [ %12, %L255 ], [ %10, %L248 ], [ %8, %L212 ], [ %6, %L176 ]
Illegal address space propagation
UNREACHABLE executed at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:452!
```

wsmoses commented 5 months ago

Not sure, since when you tried main it was already fixed. I just released a new version.

On Wed, Jan 3, 2024 at 3:15 AM Miles Cranmer wrote:

> Awesome. Seems to work now! What was the thing that fixed it?


wsmoses commented 5 months ago

Yeah go ahead and open issues for anything else that arose. My comment only referred to the correctness issue you raised earlier.

MilesCranmer commented 5 months ago

Thanks! No worries. For simplicity, I'll keep the integration tests on Julia 1.10 and ubuntu-latest, just to get this implemented finally.

Do you want me to add an integration test to Enzyme itself once this all passes and gets released?