EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Error when using `view` of `Const` to calculate `view` of `Duplicated` #1956

Open hexaeder opened 1 week ago

hexaeder commented 1 week ago

I am trying to make the RHS of an ODEProblem Enzyme-compatible. My function has the signature (du, u, p, t), and I am trying to differentiate du with respect to u, with p and t held constant. I hit the error

ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.

for some operations that use p in the calculation of du. I am quite new to Enzyme and don't fully understand this error, but in very simple examples it isn't a problem to use Const(p) to calculate Duplicated(du).
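
For reference, here is a minimal sketch of the kind of simple case that works, where a Const parameter feeds into a Duplicated output. The function name rhs! and all values are made up for illustration, and the call assumes the Enzyme 0.13-style autodiff used in the MWEs:

```julia
using Enzyme: Enzyme

# Hypothetical in-place RHS: du depends on u (active) and p (constant).
function rhs!(du, u, p, t)
    du[1] = p[1] * u[1]
    nothing
end

du = Enzyme.Duplicated(zeros(1), zeros(1))   # output and its tangent
u  = Enzyme.Duplicated([2.0], [1.0])         # seed the derivative w.r.t. u[1]
p  = Enzyme.Const([3.0])                     # parameter, not differentiated
t  = Enzyme.Const(NaN)                       # time, unused by rhs!

Enzyme.autodiff(Enzyme.Forward, rhs!, du, u, p, t)

du.val   # primal result, p[1] * u[1]
du.dval  # tangent of du w.r.t. u[1], which equals p[1] here
```

Here the value stored into du is computed from u, so it is active rather than a plain copy of constant memory.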

I boiled it down to two MWEs. The first MWE is closer to my actual code, including loop unrolling. The second MWE seems to error because of the broadcasting, but it does not need the loop unrolling to fail. I am not sure whether both demonstrate the same problem or different ones.

Both examples were created on Julia 1.10.5 with Enzyme 0.13.8. I am aware of set_runtime_activity, which works for forward mode in my actual example but segfaults in reverse mode...

MWE 1

using Pkg
pkg"activate --temp"
pkg"add Enzyme"
using Enzyme: Enzyme

@inline function unrolled_foreach(f::F, t::Tuple) where {F}
    f(first(t))
    @inline unrolled_foreach(f, Base.tail(t))
end
@inline unrolled_foreach(f::F, t::Tuple{}) where {F} = nothing

struct Functor{T}
    batches::T
end
function (f::Functor)(du, u, p, t)
    unrolled_foreach(f.batches) do batch
        for i in 1:2
            start = 1 + (i-1) * 2
            stop = start + 1
            range = start:stop

            _du  = view(du, range)
            _p   = view(p, range)
            _du[1] = _p[1]
        end
    end
    nothing
end

batches = (1,)
f = Functor(batches)

# test normal call
dx, x, p, t = zeros(4), zeros(4), collect(1.0:4.0), NaN
f(dx, x, p, t)
dx

# f_and_df = Enzyme.Duplicated(f, Enzyme.make_zero(f))
dxD = Enzyme.Duplicated(zeros(4), zeros(4))
xD = Enzyme.Duplicated(x, [1.0, 0.0, 0.0, 0.0])
pC = Enzyme.Const(p)
tC = Enzyme.Const(NaN)
Enzyme.autodiff(Enzyme.Forward, f, dxD, xD, pC, tC)

MWE 2

using Pkg
pkg"activate --temp"
pkg"add Enzyme"
using Enzyme: Enzyme

struct Functor{RT}
    range::RT
end
function (f::Functor)(du, u, p, t)
    r = f.range
    # r = 1:4 # this literal would work
    _du  = view(du, r)
    _p   = view(p, r)
    _du .= _p
    nothing
end

f = Functor(1:4)

# test normal function call
dx, x, p, t = zeros(4), zeros(4), collect(1.0:4.0), NaN
f(dx, x, p, t)
@assert dx == 1:4

dxD = Enzyme.Duplicated(zeros(4), zeros(4))
xD = Enzyme.Duplicated(x, [1.0, 0.0, 0.0, 0.0])
pC = Enzyme.Const(p)
tC = Enzyme.Const(NaN)
Enzyme.autodiff(Enzyme.Forward, f, dxD, xD, pC, tC)
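
The workaround mentioned above, set_runtime_activity, can be sketched for MWE 2 as follows. This is only the forward-mode variant, written under the assumption that it behaves like it does in my actual code; it is not the reverse-mode case that crashes:

```julia
using Enzyme: Enzyme

struct Functor{RT}
    range::RT
end
function (f::Functor)(du, u, p, t)
    r = f.range
    _du = view(du, r)
    _p  = view(p, r)
    _du .= _p          # copies constant data into the differentiable output
    nothing
end

f = Functor(1:4)
x, p = zeros(4), collect(1.0:4.0)

dxD = Enzyme.Duplicated(zeros(4), zeros(4))
xD  = Enzyme.Duplicated(x, [1.0, 0.0, 0.0, 0.0])

# Same call as in MWE 2, but with runtime activity enabled on the mode.
mode = Enzyme.set_runtime_activity(Enzyme.Forward)
Enzyme.autodiff(mode, f, dxD, xD, Enzyme.Const(p), Enzyme.Const(NaN))

dxD.dval  # tangent of du w.r.t. u[1]; du does not depend on u in this MWE
```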

wsmoses commented 6 days ago

@hexaeder is this the code that segfaults in reverse mode? (The error message implies you should use runtime activity, so that seems like the resolution.) However, a segfault is clearly bad, so I want to make sure we fix that.

hexaeder commented 5 days ago

Hi! The segfault appears in my actual code, but I wasn't able to reduce it to an MWE. If you're interested, I can try to put together a script which sets up the full objects using my packages and segfaults on the jacobian call, but I'm not sure how easy debugging in there would be.

I was mainly posting because the error message said something like "report if you think Enzyme should be able to prove no runtime activity". And I don't see why the MWEs would contain runtime activity...

wsmoses commented 5 days ago

Yeah, that would be helpful [it should never segfault].

It is indeed confusing why it would require runtime activity here, though.

hexaeder commented 4 days ago

Unfortunately, I was not able to reproduce the problem with a plain Enzyme call, only in a jacobian call from DifferentiationInterface. I tried plain Enzyme.autodiff, but the crash might be related to the batch mode used by DI, which I am not able to invoke myself.

Additional observations:

using Pkg
@assert VERSION == v"1.10.5"
pkg"activate --temp"
pkg"add NetworkDynamics#3e99370, Enzyme, Graphs, DifferentiationInterface"
using NetworkDynamics, Graphs, Enzyme
using Enzyme: Enzyme
using DifferentiationInterface: DifferentiationInterface as DI

# we need to load some test utils from NetworkDynamics
include(joinpath(pkgdir(NetworkDynamics),"test","ComponentLibrary.jl"))

# setup of the system
g = complete_graph(4)
vf = Lib.kuramoto_second()
ef = [Lib.diffusion_odeedge(),
      Lib.kuramoto_edge(),
      Lib.kuramoto_edge(),
      Lib.diffusion_edge_fid(),
      Lib.diffusion_odeedge(),
      Lib.diffusion_edge_fid()]
nw = Network(g, vf, ef)

x0 = rand(dim(nw))
dx = zeros(dim(nw))
p0 = rand(pdim(nw))

# this is the rhs we want to differentiate
# the last argument is time but it is not used in the system so I use NaN.
nw(dx, x0, p0, NaN)

# fault
DI.jacobian(nw, dx, DI.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Duplicated), x0, DI.Constant(p0), DI.Constant(NaN))
Trace

```
julia: /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instructions.h:1190: void llvm::ICmpInst::AssertOK(): Assertion `getOperand(0)->getType() == getOperand(1)->getType() && "Both operands to ICmp instruction are not of the same type!"' failed.

[282022] signal (6.-6): Aborted
in expression starting at REPL[17]:1
pthread_kill at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
raise at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x7fbc3dda171a)
__assert_fail at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
AssertOK at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instructions.h:1190 [inlined]
ICmpInst at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instructions.h:1245 [inlined]
CreateICmp at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/IRBuilder.h:2180
CreateICmpNE at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/IRBuilder.h:2071 [inlined]
addToInvertedPtrDiffe at /workspace/srcdir/Enzyme/enzyme/Enzyme/DiffeGradientUtils.cpp:1179
visitLoadLike at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:646
visitLoadInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:720
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:4303
recursivelyHandleSubfunction at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:5675
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:6412
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:4303
EnzymeCreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:631
EnzymeCreatePrimalAndGradient at /home/hw/.julia/packages/Enzyme/vgArw/src/api.jl:253
unknown function (ip: 0x7fbb4ea0c5de)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
enzyme! at /home/hw/.julia/packages/Enzyme/vgArw/src/compiler.jl:3999
unknown function (ip: 0x7fbb4ea0a6f8)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
#codegen#19013 at /home/hw/.julia/packages/Enzyme/vgArw/src/compiler.jl:7098
codegen at /home/hw/.julia/packages/Enzyme/vgArw/src/compiler.jl:5931 [inlined]
_thunk at /home/hw/.julia/packages/Enzyme/vgArw/src/compiler.jl:8206
_thunk at /home/hw/.julia/packages/Enzyme/vgArw/src/compiler.jl:8206 [inlined]
cached_compilation at /home/hw/.julia/packages/Enzyme/vgArw/src/compiler.jl:8247 [inlined]
thunkbase at /home/hw/.julia/packages/Enzyme/vgArw/src/compiler.jl:8379
unknown function (ip: 0x7fbb4ea5ef00)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
#s2070#19068 at /home/hw/.julia/packages/Enzyme/vgArw/src/compiler.jl:8516 [inlined]
#s2070#19068 at ./none:0
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
GeneratedFunctionStub at ./boot.jl:602
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_call_staged at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/method.c:540
ijl_code_for_staged at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/method.c:593
get_staged at ./compiler/utilities.jl:123
retrieve_code_info at ./compiler/utilities.jl:135 [inlined]
InferenceState at ./compiler/inferencestate.jl:430
typeinf_edge at ./compiler/typeinfer.jl:920
abstract_call_method at ./compiler/abstractinterpretation.jl:629
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:95
abstract_call_known at ./compiler/abstractinterpretation.jl:2087
abstract_call at ./compiler/abstractinterpretation.jl:2169
abstract_call at ./compiler/abstractinterpretation.jl:2162
abstract_call at ./compiler/abstractinterpretation.jl:2354
abstract_eval_call at ./compiler/abstractinterpretation.jl:2370
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2380
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2624
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2889
typeinf_local at ./compiler/abstractinterpretation.jl:3098
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3186
_typeinf at ./compiler/typeinfer.jl:247
typeinf at ./compiler/typeinfer.jl:216
typeinf_edge at ./compiler/typeinfer.jl:930
abstract_call_method at ./compiler/abstractinterpretation.jl:629
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:95
abstract_call_known at ./compiler/abstractinterpretation.jl:2087
abstract_call at ./compiler/abstractinterpretation.jl:2169
abstract_apply at ./compiler/abstractinterpretation.jl:1612
abstract_call_known at ./compiler/abstractinterpretation.jl:2004
abstract_call at ./compiler/abstractinterpretation.jl:2169
abstract_call at ./compiler/abstractinterpretation.jl:2162
abstract_call at ./compiler/abstractinterpretation.jl:2354
abstract_eval_call at ./compiler/abstractinterpretation.jl:2370
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2380
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2624
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2913
typeinf_local at ./compiler/abstractinterpretation.jl:3098
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3186
_typeinf at ./compiler/typeinfer.jl:247
typeinf at ./compiler/typeinfer.jl:216
typeinf_edge at ./compiler/typeinfer.jl:930
abstract_call_method at ./compiler/abstractinterpretation.jl:629
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:95
abstract_call_known at ./compiler/abstractinterpretation.jl:2087
abstract_call at ./compiler/abstractinterpretation.jl:2169
abstract_apply at ./compiler/abstractinterpretation.jl:1612
abstract_call_known at ./compiler/abstractinterpretation.jl:2004
abstract_call at ./compiler/abstractinterpretation.jl:2169
abstract_call at ./compiler/abstractinterpretation.jl:2162
abstract_call at ./compiler/abstractinterpretation.jl:2354
abstract_eval_call at ./compiler/abstractinterpretation.jl:2370
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2380
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2624
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2913
typeinf_local at ./compiler/abstractinterpretation.jl:3098
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3186
_typeinf at ./compiler/typeinfer.jl:247
typeinf at ./compiler/typeinfer.jl:216
typeinf_edge at ./compiler/typeinfer.jl:930
abstract_call_method at ./compiler/abstractinterpretation.jl:629
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:95
abstract_call_known at ./compiler/abstractinterpretation.jl:2087
abstract_call at ./compiler/abstractinterpretation.jl:2169
abstract_apply at ./compiler/abstractinterpretation.jl:1612
abstract_call_known at ./compiler/abstractinterpretation.jl:2004
abstract_call at ./compiler/abstractinterpretation.jl:2169
abstract_call at ./compiler/abstractinterpretation.jl:2162
abstract_call at ./compiler/abstractinterpretation.jl:2354
abstract_eval_call at ./compiler/abstractinterpretation.jl:2370
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2380
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2624
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2889
typeinf_local at ./compiler/abstractinterpretation.jl:3098
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3186
_typeinf at ./compiler/typeinfer.jl:247
typeinf at ./compiler/typeinfer.jl:216
typeinf_edge at ./compiler/typeinfer.jl:930
abstract_call_method at ./compiler/abstractinterpretation.jl:629
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:95
abstract_call_known at ./compiler/abstractinterpretation.jl:2087
abstract_call at ./compiler/abstractinterpretation.jl:2169
abstract_apply at ./compiler/abstractinterpretation.jl:1612
abstract_call_known at ./compiler/abstractinterpretation.jl:2004
abstract_call at ./compiler/abstractinterpretation.jl:2169
abstract_call at ./compiler/abstractinterpretation.jl:2162
abstract_call at ./compiler/abstractinterpretation.jl:2354
abstract_eval_call at ./compiler/abstractinterpretation.jl:2370
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2380
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2624
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2913
typeinf_local at ./compiler/abstractinterpretation.jl:3098
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3186
_typeinf at ./compiler/typeinfer.jl:247
typeinf at ./compiler/typeinfer.jl:216
typeinf_ext at ./compiler/typeinfer.jl:1051
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1082
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1078
jfptr_typeinf_ext_toplevel_35703.1 at /home/hw/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_type_infer at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:394
jl_generate_fptr_impl at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jitlayers.cpp:504
jl_compile_method_internal at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2481 [inlined]
jl_compile_method_internal at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2368
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2887 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
jacobian at /home/hw/.julia/packages/DifferentiationInterface/Y5WaD/src/fallbacks/no_prep.jl:75
unknown function (ip: 0x7fbb4ea5bfa9)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
do_call at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/interpreter.c:126
eval_value at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/interpreter.c:617
jl_interpret_toplevel_thunk at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/interpreter.c:775
jl_toplevel_eval_flex at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/toplevel.c:934
jl_toplevel_eval_flex at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/toplevel.c:877
ijl_toplevel_eval_in at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/toplevel.c:985
eval at ./boot.jl:385 [inlined]
eval_user_input at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
repl_backend_loop at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
#start_repl_backend#46 at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
start_repl_backend at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:228
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
#run_repl#59 at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
run_repl at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
jfptr_run_repl_91805.1 at /home/hw/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
#1013 at ./client.jl:432
jfptr_YY.1013_82772.1 at /home/hw/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_f__call_latest at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/builtins.c:812
#invokelatest#2 at ./essentials.jl:892 [inlined]
invokelatest at ./essentials.jl:889 [inlined]
run_main_repl at ./client.jl:416
exec_options at ./client.jl:333
_start at ./client.jl:552
jfptr__start_82798.1 at /home/hw/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
true_main at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jlapi.c:582
jl_repl_entrypoint at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jlapi.c:731
main at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/cli/loader_exe.c:58
unknown function (ip: 0x7fbc3dda2d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 79676851 (Pool: 79532925; Big: 143926); GC: 48
[1] 282022 IOT instruction (core dumped) julia +lts --startup-file=no
```

wsmoses commented 4 days ago

@gdalle re DI segfault

gdalle commented 4 days ago

In this case, nw seems to be an out-of-place function? If so, DI.jacobian uses split reverse mode with autodiff_thunk to be able to pass arbitrary adjoints with array outputs. It also wraps BatchDuplicated around the inputs (and around the function itself because of function_annotation). So a pure-Enzyme MWE would have to use all of these ingredients. Can you maybe boil it down to a simpler function passed to DI.jacobian? That would facilitate our investigation. Also note that because of this split reverse mode, any information you pass to the mode object inside AutoEnzyme is currently lost (because I use a split mode instead of a standard mode). Once https://github.com/EnzymeAD/Enzyme.jl/pull/1979 is merged, I can perform a better conversion and preserve settings like runtime activity.

hexaeder commented 4 days ago

Ah, I didn't realize that the segfault only happens in the more complex use case through DI, not in my very simple Enzyme.autodiff call.

When using DI.jacobian, both MWEs from the initial post actually crash Julia. Because it's shorter, here's the second one:

using Pkg
pkg"activate --temp"
pkg"add Enzyme, DifferentiationInterface"
using Enzyme: Enzyme
using DifferentiationInterface: DifferentiationInterface as DI

struct Functor{RT}
    range::RT
end
function (f::Functor)(du, u, p, t)
    r = f.range
    # r = 1:4 # this literal would work
    _du  = view(du, r)
    _p   = view(p, r)
    _du .= _p
    nothing
end

f = Functor(1:4)

# test normal function call
dx, x, p, t = zeros(4), zeros(4), collect(1.0:4.0), NaN
f(dx, x, p, t)
@assert dx == 1:4

#💣 
DI.jacobian(f, dx, DI.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Duplicated), x, DI.Constant(p), DI.Constant(NaN))

gdalle commented 4 days ago

Thanks for the smaller MWE! Here we are working with an in-place function, so DI can use autodiff directly and the remarks about split mode from earlier don't apply. Also, your function is a functor but it does not contain differentiable data, so the right annotation here would be function_annotation=Enzyme.Const and not Enzyme.Duplicated. Actually we can dispense with the annotation altogether, because Enzyme can prove that this enclosed data is read-only.
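
To illustrate the annotation point in isolation, here is a rough sketch with a hypothetical functor (Scale, not from the MWE) whose only field is read-only data and whose output genuinely depends on u, so no constant memory is copied into du. The DI calls mirror the ones used in this thread; whether this runs cleanly on every Enzyme/DI version is an assumption:

```julia
using Enzyme: Enzyme
using DifferentiationInterface: DifferentiationInterface as DI

# Hypothetical functor: `range` is read-only, non-differentiable data.
struct Scale{RT}
    range::RT
end
function (f::Scale)(du, u, p, t)
    for i in f.range
        du[i] = p[i] * u[i]   # the stored value is active because it depends on u
    end
    nothing
end

f = Scale(1:4)
du, u, p = zeros(4), ones(4), collect(1.0:4.0)

# Annotate the function itself as constant rather than Duplicated.
backend = DI.AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Const)
J = DI.jacobian(f, du, backend, u, DI.Constant(p), DI.Constant(NaN))
```

If it runs as intended, J should be the diagonal matrix with p on its diagonal, since du[i] = p[i] * u[i].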

So in the end, this is a tale of two runtime activities. This version errors but the error is pretty self-explanatory: you copied data from p (constant) to the output dx (differentiable).

backend_errors = DI.AutoEnzyme(; mode=Enzyme.Reverse)
DI.jacobian(f, dx, backend_errors, x, DI.Constant(p), DI.Constant(NaN))
ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.

Meanwhile, this version segfaults:

backend_segfaults = DI.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))
DI.jacobian(f, dx, backend_segfaults, x, DI.Constant(p), DI.Constant(NaN))

My best guess is that the problem comes from the runtime activity analysis?

wsmoses commented 4 days ago

@gdalle can you boil out the DI sugar to something which errs with just Enzyme calls? And paste the stack trace

gdalle commented 4 days ago

Sounds good, I'll try that tomorrow, logging out for the day!