EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

50% correctness with certain Turing models #659

Closed. torfjelde closed this 1 year ago

torfjelde commented 1 year ago

So this is quite a fun one:

julia> using Distributions, DynamicPPL, LogDensityProblems, LogDensityProblemsAD, Enzyme, LinearAlgebra

julia> @model function linear_regression(x)
           α ~ Normal(0, 1)
           β ~ Normal(0, 1)
           μ = α .+ β * x
           y ~ MvNormal(μ, I)
       end
linear_regression (generic function with 2 methods)

julia> x = randn(10);

julia> model = linear_regression(x);

julia> cmodel = model | (y = model(),);

julia> f = ADgradient(:Enzyme, DynamicPPL.LogDensityFunction(cmodel, SimpleVarInfo((α = 0.0, β = 1.0))));

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0])
(-16.97124612477541, [-5.18786388795403, -3.470190388619807])

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0])
(-16.97124612477541, [5.18786388795403, -9.41777706401886])

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0]) # different gradient!
(-16.97124612477541, [-5.18786388795403, -3.470190388619807])

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0]) # and we're back again
(-16.97124612477541, [5.18786388795403, -9.41777706401886])

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0]) # but then we gone
(-16.97124612477541, [-5.18786388795403, -3.470190388619807])

One of these is correct, as can be seen by using ForwardDiff:

julia> using ForwardDiff

julia> f_forwarddiff = ADgradient(:ForwardDiff, DynamicPPL.LogDensityFunction(cmodel, SimpleVarInfo((α = 0.0, β = 1.0))));

julia> LogDensityProblems.logdensity_and_gradient(f_forwarddiff, [0.0, 1.0])
(-16.97124612477541, [5.18786388795403, -9.417777064018862])

julia> LogDensityProblems.logdensity_and_gradient(f_forwarddiff, [0.0, 1.0])
(-16.97124612477541, [5.18786388795403, -9.417777064018862])

It's all very interesting :thinking:
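
For reference, a quick way to see that the flip is systematic rather than random (a sketch, reusing `f` and the evaluation point from above):

grads = [LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0])[2] for _ in 1:6]
unique(grads)  # a deterministic gradient would yield a single unique vector; the transcript above suggests two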

wsmoses commented 1 year ago

Well, this is definitely interesting. 1) Can you try Enzyme#main? 2) Can you rewrite this to call autodiff directly?

torfjelde commented 1 year ago

Can you rewrite this to directly call autodiff?

This is a good shout; could be a bug in how LogDensityProblemsAD uses Enzyme?

wsmoses commented 1 year ago

Wait, did that fix it?

I mean at minimum it lets us try to analyze the Enzyme-specific parts of the bug (and then write a test).

devmotion commented 1 year ago

could be a bug in how LogDensityProblemsAD uses Enzyme?

This would be interesting, since @wsmoses also contributed to it and improved my initial version :stuck_out_tongue: There's also no cache involved (https://github.com/tpapp/LogDensityProblemsAD.jl/blob/e184a8ed5bca9c90b28b335de291b8390f868789/ext/LogDensityProblemsADEnzymeExt.jl#L70-L75), so I assume it's unlikely that the interface in LogDensityProblemsAD is the main culprit. (BTW, I already opened a PR for the upcoming features in Enzyme 0.11: https://github.com/tpapp/LogDensityProblemsAD.jl/pull/10)
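
For context, a cache-free wrapper of this kind boils down to roughly the following (a paraphrased sketch for illustration, not the exact code at the linked commit):

# Hypothetical sketch of a stateless Enzyme-based gradient wrapper for a
# LogDensityProblems-style problem ℓ: a fresh shadow buffer is allocated on
# every call, so no state can leak between calls.
function logdensity_and_gradient_sketch(ℓ, x::Vector{Float64})
    ∂ℓ_∂x = zero(x)
    Enzyme.autodiff(Enzyme.Reverse, LogDensityProblems.logdensity,
                    Enzyme.Const(ℓ), Enzyme.Duplicated(x, ∂ℓ_∂x))
    return LogDensityProblems.logdensity(ℓ, x), ∂ℓ_∂x
end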

torfjelde commented 1 year ago

Wait, did that fix it?

Ah, no, I haven't had the chance to try it yet; I just thought it might have something to do with this. But given @devmotion's comment, that seems less likely.

devmotion commented 1 year ago
  1. can you try Enzyme#main

I can reproduce the example with Enzyme#main + EnzymeCore#main.

  2. Can you rewrite this to directly call autodiff?

The example can be reproduced also with autodiff (tested with Enzyme#main + EnzymeCore#main):

julia> ... # the definitions above

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0])
(-34.79163260301861, [7.0148317141196355, -28.486511693195027])

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0])
(-34.79163260301861, [-7.0148317141196355, -20.042228715930044])

julia> logdensity = Base.Fix1(LogDensityProblems.logdensity, DynamicPPL.LogDensityFunction(cmodel, SimpleVarInfo((α = 0.0, β = 1.0))));

julia> logdensity([0.0, 1.0])
-34.79163260301861

julia> autodiff(Enzyme.Reverse, logdensity, Duplicated([0.0, 1.0], Δx))
((nothing,),)

julia> Δx
2-element Vector{Float64}:
   7.0148317141196355
 -28.486511693195027

julia> Δx = zeros(2)
2-element Vector{Float64}:
 0.0
 0.0

julia> autodiff(Enzyme.Reverse, logdensity, Duplicated([0.0, 1.0], Δx))
((nothing,),)

julia> Δx
2-element Vector{Float64}:
  -7.0148317141196355
 -20.042228715930044
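
(As an aside, the reason Δx has to be re-zeroed between calls, as done above, is that Enzyme accumulates reverse-mode derivatives into the shadow buffer. A minimal standalone sketch of that behaviour, unrelated to the bug itself:)

using Enzyme

g(x) = sum(abs2, x)
x  = [1.0, 2.0]
dx = zeros(2)                                    # shadow buffer; Enzyme adds gradients into it

autodiff(Enzyme.Reverse, g, Duplicated(x, dx))   # dx is now [2.0, 4.0]
autodiff(Enzyme.Reverse, g, Duplicated(x, dx))   # without resetting, dx becomes [4.0, 8.0]
fill!(dx, 0.0)                                   # reset the shadow before reuse
autodiff(Enzyme.Reverse, g, Duplicated(x, dx))   # dx is [2.0, 4.0] again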

It's possible to eliminate LogDensityProblems completely, but currently I have to restart and rerun since Julia segfaulted when I tried ReverseWithPrimal:

julia> autodiff(Enzyme.ReverseWithPrimal, logdensity, Duplicated([0.0, 1.0], Δx))
 rep:   %56 = bitcast {}* %35 to { { [2 x double], double }, {} addrspace(10)* }* prev:   %57 = bitcast {} addrspace(10)* %37 to { { [2 x double], double }, {} addrspace(10)* } addrspace(10)*, !enzyme_caststack !4 inst:   store { { [2 x double], double }, {} addrspace(10)* } addrspace(10)* %57, { { [2 x double], double }, {} addrspace(10)* } addrspace(10)** %58, align 8
Illegal address space propagation
UNREACHABLE executed at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:406!

signal (6): Aborted
in expression starting at REPL[25]:1
__pthread_kill_implementation at /lib64/libc.so.6 (unknown line)
gsignal at /lib64/libc.so.6 (unknown line)
abort at /lib64/libc.so.6 (unknown line)
_ZN4llvm25llvm_unreachable_internalEPKcS1_j at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/bin/../lib/julia/libLLVM-13jl.so (unknown line)
RecursivelyReplaceAddressSpace at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:406
LowerAllocAddr at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:722
CreateAugmentedPrimal at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:2648
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:12137
delegateCallInst at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:302 [inlined]
visitCall at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209 [inlined]
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:112 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:3938
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:12497
delegateCallInst at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:302 [inlined]
visitCall at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209 [inlined]
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:112 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:3938
EnzymeCreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:470
EnzymeCreatePrimalAndGradient at /home/david/.julia/packages/Enzyme/PjX8s/src/api.jl:123
enzyme! at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:6324
unknown function (ip: 0x7f17f4ba1e83)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
#codegen#134 at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:7578
codegen##kw at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:7208 [inlined]
_thunk at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:8090 [inlined]
_thunk at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:8084
unknown function (ip: 0x7f17f4b9cc3d)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
cached_compilation at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:8128
unknown function (ip: 0x7f17f4f4cef3)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
#s451#163 at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:8188 [inlined]
#s451#163 at ./none:0
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
GeneratedFunctionStub at ./boot.jl:582
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
jl_apply at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/julia.h:1843 [inlined]
jl_call_staged at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/method.c:520
ijl_code_for_staged at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/method.c:571
get_staged at ./compiler/utilities.jl:114
retrieve_code_info at ./compiler/utilities.jl:126 [inlined]
InferenceState at ./compiler/inferencestate.jl:284
typeinf_edge at ./compiler/typeinfer.jl:868
abstract_call_method at ./compiler/abstractinterpretation.jl:647
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:139
abstract_call_known at ./compiler/abstractinterpretation.jl:1716
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_call at ./compiler/abstractinterpretation.jl:1753
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1910
typeinf_local at ./compiler/abstractinterpretation.jl:2386
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2482
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:647
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:139
abstract_call_known at ./compiler/abstractinterpretation.jl:1716
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_call at ./compiler/abstractinterpretation.jl:1753
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1910
typeinf_local at ./compiler/abstractinterpretation.jl:2386
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2482
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:647
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:139
abstract_call_known at ./compiler/abstractinterpretation.jl:1716
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_call at ./compiler/abstractinterpretation.jl:1753
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1910
typeinf_local at ./compiler/abstractinterpretation.jl:2360
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2482
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:647
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:139
abstract_call_known at ./compiler/abstractinterpretation.jl:1716
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_apply at ./compiler/abstractinterpretation.jl:1357
abstract_call_known at ./compiler/abstractinterpretation.jl:1620
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_call at ./compiler/abstractinterpretation.jl:1753
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1910
typeinf_local at ./compiler/abstractinterpretation.jl:2386
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2482
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_ext at ./compiler/typeinfer.jl:967
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1000
typeinf_ext_toplevel at ./compiler/typeinfer.jl:996
jfptr_typeinf_ext_toplevel_17539.clone_1 at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
jl_apply at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/julia.h:1843 [inlined]
jl_type_infer at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:315
jl_generate_fptr_impl at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/jitlayers.cpp:319
jl_compile_method_internal at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2091 [inlined]
jl_compile_method_internal at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2035
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2369 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
jl_apply at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/julia.h:1843 [inlined]
do_call at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/interpreter.c:126
eval_value at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/interpreter.c:215
eval_stmt_value at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/interpreter.c:166 [inlined]
eval_body at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/interpreter.c:612
jl_interpret_toplevel_thunk at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/interpreter.c:750
jl_toplevel_eval_flex at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/toplevel.c:906
jl_toplevel_eval_flex at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/toplevel.c:850
jl_toplevel_eval_flex at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/toplevel.c:850
jl_toplevel_eval_flex at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/toplevel.c:850
ijl_toplevel_eval_in at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/toplevel.c:965
eval at ./boot.jl:368 [inlined]
eval_user_input at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:151
repl_backend_loop at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:247
start_repl_backend at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:232
#run_repl#47 at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:369
run_repl at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:355
jfptr_run_repl_65104.clone_1 at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
#967 at ./client.jl:419
jfptr_YY.967_33139.clone_1 at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
jl_apply at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/julia.h:1843 [inlined]
jl_f__call_latest at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/builtins.c:774
#invokelatest#2 at ./essentials.jl:729 [inlined]
invokelatest at ./essentials.jl:726 [inlined]
run_main_repl at ./client.jl:404
exec_options at ./client.jl:318
_start at ./client.jl:522
jfptr__start_38041.clone_1 at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
jl_apply at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/julia.h:1843 [inlined]
true_main at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/jlapi.c:575
jl_repl_entrypoint at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/jlapi.c:719
main at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/cli/loader_exe.c:59
__libc_start_call_main at /lib64/libc.so.6 (unknown line)
__libc_start_main at /lib64/libc.so.6 (unknown line)
unknown function (ip: 0x401098)
Allocations: 114092184 (Pool: 113988457; Big: 103727); GC: 80
wsmoses commented 1 year ago

@devmotion I thought we got rid of the Fix1 earlier by passing another (const) argument?

devmotion commented 1 year ago

Yeah, that's not part of LogDensityProblems, only of my simpler example. I only remembered that Fix1/Fix2 are not needed with Enzyme after Julia had already segfaulted :smile:
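
For reference, the two styles differ only in where the non-differentiated data lives (a standalone sketch with a hypothetical toy problem, not Turing code):

using Enzyme

struct ToyProblem            # hypothetical stand-in for a log-density problem
    σ::Float64
end
toylogp(p::ToyProblem, x) = -sum(abs2, x) / (2 * p.σ^2)

p, x, dx = ToyProblem(1.0), [0.0, 1.0], zeros(2)

# Closure style: capture the problem with Base.Fix1 and differentiate the closure.
autodiff(Enzyme.Reverse, Base.Fix1(toylogp, p), Duplicated(x, dx))

# Extra-argument style: pass the problem explicitly as a constant argument instead.
fill!(dx, 0.0)
autodiff(Enzyme.Reverse, toylogp, Const(p), Duplicated(x, dx))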

devmotion commented 1 year ago

OK, completely without LogDensityProblems:

(jl_6gvLdX) pkg> add Enzyme#main EnzymeCore#main Distributions LogDensityProblems LogDensityProblemsAD

julia> using Distributions, DynamicPPL, LogDensityProblems, LogDensityProblemsAD, Enzyme, LinearAlgebra

julia> @model function linear_regression(x)
           α ~ Normal(0, 1)
           β ~ Normal(0, 1)
           μ = α .+ β * x
           y ~ MvNormal(μ, I)
       end
linear_regression (generic function with 2 methods)

julia> x = randn(10);

julia> model = linear_regression(x);

julia> cmodel = model | (y = model(),);

julia> logjoint(cmodel, SimpleVarInfo((α = 0.0, β = 1.0)))
-27.305144853104405

julia> f = ADgradient(:Enzyme, DynamicPPL.LogDensityFunction(cmodel, SimpleVarInfo((α = 0.0, β = 1.0))));

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0])
(-27.305144853104405, [9.942031055458013, 12.437679194765497])

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0]) # same problem as above
(-27.305144853104405, [-9.942031055458013, -45.993444104062156])

julia> logp(model, α, β) = logjoint(model, SimpleVarInfo((; α, β)))
logp (generic function with 1 method)

julia> autodiff(Enzyme.Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, 9.942031055458013, 12.437679194765499),)

julia> autodiff(Enzyme.Reverse, logp, cmodel, Active(0.0), Active(1.0)) # same problem, with `autodiff` and w/o LogDensityProblems
((nothing, -9.942031055458013, -45.993444104062156),)

julia> using ForwardDiff

julia> ForwardDiff.gradient(x -> logp(cmodel, x[1], x[2]), [0.0, 1.0])
2-element Vector{Float64}:
  9.942031055458015
 12.437679194765499

julia> autodiff(Enzyme.ReverseWithPrimal, logp, cmodel, Active(0.0), Active(1.0))
 rep:   %55 = bitcast {}* %34 to { { [2 x double], double }, {} addrspace(10)* }* prev:   %56 = bitcast {} addrspace(10)* %36 to { { [2 x double], double }, {} addrspace(10)* } addrspace(10)*, !enzyme_caststack !4 inst:   store { { [2 x double], double }, {} addrspace(10)* } addrspace(10)* %56, { { [2 x double], double }, {} addrspace(10)* } addrspace(10)** %57, align 8
Illegal address space propagation
UNREACHABLE executed at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:406!

signal (6): Aborted
in expression starting at REPL[25]:1
__pthread_kill_implementation at /lib64/libc.so.6 (unknown line)
gsignal at /lib64/libc.so.6 (unknown line)
abort at /lib64/libc.so.6 (unknown line)
_ZN4llvm25llvm_unreachable_internalEPKcS1_j at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/bin/../lib/julia/libLLVM-13jl.so (unknown line)
RecursivelyReplaceAddressSpace at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:406
LowerAllocAddr at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:722
CreateAugmentedPrimal at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:2648
visitCallInst at /workspace/s
...
devmotion commented 1 year ago

Another fun observation: Forward-mode returns yet another set of derivatives:

julia> ... # as above

julia> logp(cmodel, 0.0, 1.0)
-20.712423408307206

julia> autodiff(Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, 11.908124667958683, -7.800343891290028),)

julia> autodiff(Reverse, logp, cmodel, Active(0.0), Active(1.0)) # problem discussed above
((nothing, -11.908124667958683, -12.56997812841224),)

julia> autodiff(Forward, logp, BatchDuplicated, cmodel, BatchDuplicated(0.0, (1.0, 0.0)), BatchDuplicated(1.0, (0.0, 1.0)))
  %"'ipc68" = addrspacecast { [1 x {} addrspace(10)*] }* %14 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !132
  %49 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] undef, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc68", 0, !dbg !132
  %"'ipc69" = addrspacecast { [1 x {} addrspace(10)*] }* %17 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !132
  %52 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] %51, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc69", 1, !dbg !132
  %"'ipc30.i" = addrspacecast { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } }* %16 to { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*, !dbg !47
  %52 = insertvalue [2 x { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*] %51, { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)* %"'ipc30.i", 1, !dbg !47
  %"'ipc57.i" = addrspacecast { [1 x {} addrspace(10)*] }* %17 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !187
  %69 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] %68, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc57.i", 1, !dbg !187
  %"'ipc56.i" = addrspacecast { [1 x {} addrspace(10)*] }* %18 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !187
  %69 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] undef, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc56.i", 0, !dbg !187
  %"'ipc.i" = addrspacecast { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } }* %19 to { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*, !dbg !47
  %54 = insertvalue [2 x { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*] undef, { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)* %"'ipc.i", 0, !dbg !47
(-20.712423408307206, (var"1" = -6.462197351743586, var"2" = -26.1706659109923))
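
(For comparison, the BatchDuplicated seeds above just pack the two coordinate directions into one call; the equivalent single-direction calls would be roughly the following sketch, which should return var"1" and var"2" respectively:)

autodiff(Forward, logp, Duplicated, cmodel, Duplicated(0.0, 1.0), Duplicated(1.0, 0.0))  # ∂/∂α
autodiff(Forward, logp, Duplicated, cmodel, Duplicated(0.0, 0.0), Duplicated(1.0, 1.0))  # ∂/∂β
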
wsmoses commented 1 year ago

@devmotion Okay, yeah, this is going to need more minimization. In particular, getting the function being AD'd as small as possible, with as few dependencies as possible, will let us investigate what is going on and why.

devmotion commented 1 year ago

Even simpler than the last examples? They already use the API in DynamicPPL without LogDensityProblems and are based only on input types defined in DynamicPPL.

wsmoses commented 1 year ago

Yes, e.g. removing the dependency on DynamicPPL, the custom macro, and ideally even the model itself.

devmotion commented 1 year ago

All the fundamental pieces of the example (@model, the Model type itself, SimpleVarInfo, and the whole logic for logjoint) are defined in DynamicPPL, so I think there won't be anything left if it is removed :grimacing: It is already the condensed core of Turing.

wsmoses commented 1 year ago

Sorry, I mean try to remove the dependency -- not remove the code. That is, can you inline the definitions of these macros, packages, etc. so that I can run this code without installing and/or understanding what these packages do?

We have to minimize the actual number of instructions Julia generates in order to find what is going wrong, and if I don't understand what a piece of syntax means or how a package works, I cannot do that effectively.
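
For what it's worth, a hand-inlined version of the log joint for this particular model could look roughly like the sketch below (using only Distributions; this is an assumption about what the @model macro amounts to semantically here, not the actual generated code):

using Distributions, LinearAlgebra, Enzyme

# Hypothetical hand-written log joint of the conditioned linear-regression model above.
function logjoint_manual(x, y, α, β)
    lp  = logpdf(Normal(0, 1), α)
    lp += logpdf(Normal(0, 1), β)
    lp += logpdf(MvNormal(α .+ β .* x, I), y)
    return lp
end

x = randn(10)
y = 0.5 .+ 2.0 .* x .+ randn(10)
autodiff(Enzyme.Reverse, logjoint_manual, Const(x), Const(y), Active(0.0), Active(1.0))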

wsmoses commented 1 year ago

On the bright side, it appears that the latest Enzyme proper (i.e. the jll) main actually fixes the reverse-mode correctness issue.

Still investigating the forward-mode issues.

devmotion commented 1 year ago

It involves so many internals that unwrapping it is tedious, so I'm doing it step by step. But it gets more and more confusing: after unwrapping one step, I get yet another set of completely different derivatives:

julia> ... # cmodel from above

julia> logp(model, α, β) = logjoint(model, SimpleVarInfo((; α, β)))
logp (generic function with 1 method)

julia> logp(cmodel, 0.0, 1.0)
-62.628671993804964

julia> autodiff(Enzyme.Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, -16.547138600446626, -28.39510526890181),)

julia> autodiff(Enzyme.Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, 16.547138600446626, -75.80771392179598),)

julia> logp2(model, α, β) = last(DynamicPPL._evaluate!!(model, SimpleVarInfo((; α, β)), DefaultContext())).logp
logp2 (generic function with 1 method)

julia> logp2(cmodel, 0.0, 1.0)
-62.628671993804964

julia> autodiff(Enzyme.Reverse, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, -16.547138600446626, -28.395105268901798),)

julia> autodiff(Enzyme.Reverse, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, 1.0824674490095276e-15, -0.9999999999999989),)

julia> autodiff(Enzyme.Reverse, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, 0.0, -1.0),)

julia> autodiff(Enzyme.Reverse, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, 0.0, -1.0),)
wsmoses commented 1 year ago

At this point, ignore the Reverse mode issue, since I think that is fixed as soon as we bump the jll. Minimizing the LLVM-level ReverseWithPrimal error, however, is most helpful.

edit: changed to say ReverseWithPrimal

wsmoses commented 1 year ago

And it doesn't matter what it prints out, so long as it causes the LLVM-level error.

devmotion commented 1 year ago

The segfault was caused by ReverseWithPrimal, or are you referring to the % ... lines in the output?

I also ran forward mode and compared logp and logp2 from the comment above:

julia> autodiff(Forward, logp, BatchDuplicated, cmodel, BatchDuplicated(0.0, (1.0, 0.0)), BatchDuplicated(1.0, (0.0, 1.0)))
  %"'ipc68" = addrspacecast { [1 x {} addrspace(10)*] }* %14 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !132
  %49 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] undef, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc68", 0, !dbg !132
  %"'ipc69" = addrspacecast { [1 x {} addrspace(10)*] }* %17 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !132
  %52 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] %51, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc69", 1, !dbg !132
  %"'ipc57.i" = addrspacecast { [1 x {} addrspace(10)*] }* %16 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !187
  %68 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] %67, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc57.i", 1, !dbg !187
  %"'ipc56.i" = addrspacecast { [1 x {} addrspace(10)*] }* %17 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !187
  %68 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] undef, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc56.i", 0, !dbg !187
  %"'ipc30.i" = addrspacecast { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } }* %18 to { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*, !dbg !47
  %54 = insertvalue [2 x { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*] %53, { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)* %"'ipc30.i", 1, !dbg !47
  %"'ipc.i" = addrspacecast { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } }* %19 to { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*, !dbg !47
  %54 = insertvalue [2 x { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*] undef, { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)* %"'ipc.i", 0, !dbg !47
(-11.527262398456074, (var"1" = 0.0, var"2" = -1.0))

julia> autodiff(Forward, logp2, BatchDuplicated, cmodel, BatchDuplicated(0.0, (1.0, 0.0)), BatchDuplicated(1.0, (0.0, 1.0)))
  %"'ipc27.i" = addrspacecast { [1 x {} addrspace(10)*] }* %11 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !83
  %36 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] %35, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc27.i", 1, !dbg !83
  %"'ipc26.i" = addrspacecast { [1 x {} addrspace(10)*] }* %12 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !83
  %36 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] undef, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc26.i", 0, !dbg !83
(-11.527262398456074, (var"1" = 0.0, var"2" = -1.0))

So it seems at least they are consistent - and they match the result that the reverse-mode derivatives of logp2 "converge" to?

wsmoses commented 1 year ago

No, sorry, I meant to say ReverseWithPrimal. As you can see, this is a new (not yet released) feature, hence it needs some more love/testing to get all the errors surfaced back at the Julia level.

devmotion commented 1 year ago

So it seems at least they are consistent - and they match the result that the reverse-mode derivatives of logp2 "converge" to?

And they agree with ForwardDiff, it seems:

julia> ForwardDiff.gradient(x -> logp(cmodel, x[1], x[2]), [0.0, 1.0])
2-element Vector{Float64}:
  0.0
 -1.0

julia> ForwardDiff.gradient(x -> logp2(cmodel, x[1], x[2]), [0.0, 1.0])
2-element Vector{Float64}:
  0.0
 -1.0
wsmoses commented 1 year ago

It is weird that forward mode is printing, though... That just seems to be a println that wasn't removed (and which we should also separately try to minimize and track down).

devmotion commented 1 year ago

:slightly_smiling_face: I think the problem with ReverseWithPrimal can be found by looking at what changed when simplifying logp to logp2, since it works for logp2:

julia> autodiff(ReverseWithPrimal, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, 0.0, -1.0), -11.527262398456074)

julia> logp2(cmodel, 0.0, 1.0)
-11.527262398456074

Another interesting observation is that apparently the "convergence" observed above (everything still in the same Julia session) was caused by some unintended mutation of the model. So it seems running autodiff modified some of the values in the model even though it shouldn't have?

I'll try to confirm this with a newly sampled model instance.
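
A cheap way to check for such mutation (a sketch reusing the objects above) is to re-evaluate the primal before and after differentiating:

lp_before = logp2(cmodel, 0.0, 1.0)
autodiff(Enzyme.Reverse, logp2, cmodel, Active(0.0), Active(1.0))
lp_after = logp2(cmodel, 0.0, 1.0)
lp_before == lp_after || @warn "cmodel appears to have been mutated by autodiff" lp_before lp_after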

devmotion commented 1 year ago

OK, so ReverseWithPrimal seems to work with logp2 consistently and shows the same mutating behaviour:

julia> x = randn(10);

julia> model = linear_regression(x);

julia> cmodel = model | (y = model(),);

julia> logp(cmodel, 0.0, 1.0)
-25.473371825268227

julia> logp2(cmodel, 0.0, 1.0)
-25.473371825268227

julia> autodiff(ReverseWithPrimal, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, -12.193412883437983, -9.213664141918544), -25.473371825268227)

julia> autodiff(ReverseWithPrimal, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, -1.1102230246251565e-16, -1.0000000000000002), -11.527262398456074)

julia> autodiff(ReverseWithPrimal, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, 0.0, -1.0), -11.527262398456074)
devmotion commented 1 year ago

Even more evidence that the different behaviour is caused by differences between DynamicPPL.evaluate_threadsafe!! (used by logjoint on my computer since I use multiple threads) and DynamicPPL.evaluate_threadunsafe!! (which calls _evaluate!!, as in the logp2 definition above):

julia> ... # cmodel, logp, and logp2 as above

julia> # this is called by `logjoint` (on my computer), so behaviour should be the same as for `logp`
           logp3(model, α, β) = last(DynamicPPL.evaluate_threadsafe!!(model, SimpleVarInfo((; α, β)), DefaultContext())).logp

julia> # `evaluate_threadunsafe!!` calls `_evaluate!!`, so behaviour should be same as for `logp2`
           logp4(model, α, β) = last(DynamicPPL.evaluate_threadunsafe!!(model, SimpleVarInfo((; α, β)), DefaultContext())).logp

julia> logp(cmodel, 0.0, 1.0)
-60.156275507914806

julia> logp2(cmodel, 0.0, 1.0)
-60.156275507914806

julia> logp3(cmodel, 0.0, 1.0)
-60.156275507914806

julia> logp4(cmodel, 0.0, 1.0)
-60.156275507914806

julia> autodiff(Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, 26.83640954931467, -30.79381366952238),)

julia> autodiff(Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, -26.83640954931467, -68.46421254939509),)

julia> autodiff(Reverse, logp3, cmodel, Active(0.0), Active(1.0))
((nothing, 26.83640954931467, -30.79381366952238),)

julia> autodiff(Reverse, logp3, cmodel, Active(0.0), Active(1.0)) # same as logp
((nothing, -26.83640954931467, -68.46421254939509),)

julia> autodiff(Reverse, logp3, cmodel, Active(0.0), Active(1.0))
((nothing, 26.83640954931467, -30.79381366952238),)

julia> autodiff(Reverse, logp4, cmodel, Active(0.0), Active(1.0)) # same as logp2
((nothing, -26.83640954931467, -68.46421254939509),)

julia> autodiff(Reverse, logp4, cmodel, Active(0.0), Active(1.0))
((nothing, -2.220446049250313e-16, -1.0000000000000009),)

The functions are defined at https://github.com/TuringLang/DynamicPPL.jl/blob/b23acff013a9111c8ce2c89dbf5339e76234d120/src/model.jl#L550 and https://github.com/TuringLang/DynamicPPL.jl/blob/b23acff013a9111c8ce2c89dbf5339e76234d120/src/model.jl#L565. Very roughly, the main difference is that evaluate_threadsafe!! wraps the varinfo in a data structure with very basic support for threaded log-density evaluation, which allows users to add @threads for ... in their models. The design is not good (anymore): https://github.com/TuringLang/DynamicPPL.jl/issues/429. Could this use of threadid, or some other Threads-related thing, explain the differences in the results and why ReverseWithPrimal causes a segfault?
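
For those unfamiliar with the internals, the wrapper is essentially a per-thread log-probability accumulator along the lines of the following sketch (hypothetical names, not the actual DynamicPPL code):

# Sketch of a thread-safe varinfo wrapper: one accumulator per thread,
# written to via threadid() and summed at the end of evaluation.
struct ThreadSafeAccumSketch{V}
    varinfo::V
    logps::Vector{Float64}
end
ThreadSafeAccumSketch(vi) = ThreadSafeAccumSketch(vi, zeros(Threads.nthreads()))

acclogp!(tsvi::ThreadSafeAccumSketch, lp) = (tsvi.logps[Threads.threadid()] += lp; tsvi)
getlogp(tsvi::ThreadSafeAccumSketch) = sum(tsvi.logps)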

wsmoses commented 1 year ago

Both the correctness and LLVM-level errors should now be fixed on latest main. Please reopen if they persist.

devmotion commented 1 year ago

Great, thank you! I checked logp and logp2 (which should cover both branches), and Reverse, Forward, and ReverseWithPrimal all computed the correct result (checked with ForwardDiff), also when calling them multiple times.

So I guess that means that in this specific example there hasn't been a problem in Turing/DynamicPPL and we don't have to change anything?

torfjelde commented 1 year ago

Awesome; thank you @wsmoses !