Closed: torfjelde closed this issue 1 year ago
Well this is definitely interesting. 1) can you try Enzyme#main 2) Can you rewrite this to directly call autodiff?
> Can you rewrite this to directly call autodiff?
This is a good shout; could be a bug in how LogDensityProblemsAD uses Enzyme?
Wait, did that fix it?
I mean at minimum it lets us try to analyze the Enzyme-specific parts of the bug (and then write a test).
> could be a bug in how LogDensityProblemsAD uses Enzyme?
This would be interesting since @wsmoses also contributed to it and improved my initial version :stuck_out_tongue: There's also no cache involved (https://github.com/tpapp/LogDensityProblemsAD.jl/blob/e184a8ed5bca9c90b28b335de291b8390f868789/ext/LogDensityProblemsADEnzymeExt.jl#L70-L75), so I assume it's unlikely that the interface in LogDensityProblemsAD is the main culprit. (BTW already opened a PR for upcoming features in Enzyme 0.11: https://github.com/tpapp/LogDensityProblemsAD.jl/pull/10)
> Wait, did that fix it?
Ah no, I haven't had the chance to try yet. I just thought it was possible that it had something to do with this. But given @devmotion's comment it seems less likely.
Again, haven't had the time to try it yet though.
> - can you try Enzyme#main
I can reproduce the example with Enzyme#main + EnzymeCore#main.
> - Can you rewrite this to directly call autodiff?
The example can also be reproduced with `autodiff` (tested with Enzyme#main + EnzymeCore#main):
julia> ... # the definitions above
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0])
(-34.79163260301861, [7.0148317141196355, -28.486511693195027])
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0])
(-34.79163260301861, [-7.0148317141196355, -20.042228715930044])
julia> logdensity = Base.Fix1(LogDensityProblems.logdensity, DynamicPPL.LogDensityFunction(cmodel, SimpleVarInfo((α = 0.0, β = 1.0))));
julia> logdensity([0.0, 1.0])
-34.79163260301861
julia> autodiff(Enzyme.Reverse, logdensity, Duplicated([0.0, 1.0], Δx))
((nothing,),)
julia> Δx
2-element Vector{Float64}:
7.0148317141196355
-28.486511693195027
julia> Δx = zeros(2)
2-element Vector{Float64}:
0.0
0.0
julia> autodiff(Enzyme.Reverse, logdensity, Duplicated([0.0, 1.0], Δx))
((nothing,),)
julia> Δx
2-element Vector{Float64}:
-7.0148317141196355
-20.042228715930044
It's possible to eliminate LogDensityProblems completely, but currently I have to restart and rerun since Julia segfaulted when I tried to use `ReverseWithPrimal`:
julia> autodiff(Enzyme.ReverseWithPrimal, logdensity, Duplicated([0.0, 1.0], Δx))
rep: %56 = bitcast {}* %35 to { { [2 x double], double }, {} addrspace(10)* }* prev: %57 = bitcast {} addrspace(10)* %37 to { { [2 x double], double }, {} addrspace(10)* } addrspace(10)*, !enzyme_caststack !4 inst: store { { [2 x double], double }, {} addrspace(10)* } addrspace(10)* %57, { { [2 x double], double }, {} addrspace(10)* } addrspace(10)** %58, align 8
Illegal address space propagation
UNREACHABLE executed at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:406!
signal (6): Aborted
in expression starting at REPL[25]:1
__pthread_kill_implementation at /lib64/libc.so.6 (unknown line)
gsignal at /lib64/libc.so.6 (unknown line)
abort at /lib64/libc.so.6 (unknown line)
_ZN4llvm25llvm_unreachable_internalEPKcS1_j at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/bin/../lib/julia/libLLVM-13jl.so (unknown line)
RecursivelyReplaceAddressSpace at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:406
LowerAllocAddr at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:722
CreateAugmentedPrimal at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:2648
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:12137
delegateCallInst at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:302 [inlined]
visitCall at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209 [inlined]
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:112 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:3938
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:12497
delegateCallInst at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:302 [inlined]
visitCall at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209 [inlined]
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:112 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:3938
EnzymeCreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:470
EnzymeCreatePrimalAndGradient at /home/david/.julia/packages/Enzyme/PjX8s/src/api.jl:123
enzyme! at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:6324
unknown function (ip: 0x7f17f4ba1e83)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
#codegen#134 at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:7578
codegen##kw at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:7208 [inlined]
_thunk at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:8090 [inlined]
_thunk at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:8084
unknown function (ip: 0x7f17f4b9cc3d)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
cached_compilation at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:8128
unknown function (ip: 0x7f17f4f4cef3)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
#s451#163 at /home/david/.julia/packages/Enzyme/PjX8s/src/compiler.jl:8188 [inlined]
#s451#163 at ./none:0
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
GeneratedFunctionStub at ./boot.jl:582
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
jl_apply at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/julia.h:1843 [inlined]
jl_call_staged at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/method.c:520
ijl_code_for_staged at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/method.c:571
get_staged at ./compiler/utilities.jl:114
retrieve_code_info at ./compiler/utilities.jl:126 [inlined]
InferenceState at ./compiler/inferencestate.jl:284
typeinf_edge at ./compiler/typeinfer.jl:868
abstract_call_method at ./compiler/abstractinterpretation.jl:647
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:139
abstract_call_known at ./compiler/abstractinterpretation.jl:1716
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_call at ./compiler/abstractinterpretation.jl:1753
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1910
typeinf_local at ./compiler/abstractinterpretation.jl:2386
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2482
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:647
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:139
abstract_call_known at ./compiler/abstractinterpretation.jl:1716
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_call at ./compiler/abstractinterpretation.jl:1753
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1910
typeinf_local at ./compiler/abstractinterpretation.jl:2386
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2482
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:647
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:139
abstract_call_known at ./compiler/abstractinterpretation.jl:1716
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_call at ./compiler/abstractinterpretation.jl:1753
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1910
typeinf_local at ./compiler/abstractinterpretation.jl:2360
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2482
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:647
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:139
abstract_call_known at ./compiler/abstractinterpretation.jl:1716
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_apply at ./compiler/abstractinterpretation.jl:1357
abstract_call_known at ./compiler/abstractinterpretation.jl:1620
abstract_call at ./compiler/abstractinterpretation.jl:1786
abstract_call at ./compiler/abstractinterpretation.jl:1753
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1910
typeinf_local at ./compiler/abstractinterpretation.jl:2386
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2482
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_ext at ./compiler/typeinfer.jl:967
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1000
typeinf_ext_toplevel at ./compiler/typeinfer.jl:996
jfptr_typeinf_ext_toplevel_17539.clone_1 at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
jl_apply at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/julia.h:1843 [inlined]
jl_type_infer at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:315
jl_generate_fptr_impl at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/jitlayers.cpp:319
jl_compile_method_internal at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2091 [inlined]
jl_compile_method_internal at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2035
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2369 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
jl_apply at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/julia.h:1843 [inlined]
do_call at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/interpreter.c:126
eval_value at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/interpreter.c:215
eval_stmt_value at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/interpreter.c:166 [inlined]
eval_body at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/interpreter.c:612
jl_interpret_toplevel_thunk at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/interpreter.c:750
jl_toplevel_eval_flex at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/toplevel.c:906
jl_toplevel_eval_flex at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/toplevel.c:850
jl_toplevel_eval_flex at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/toplevel.c:850
jl_toplevel_eval_flex at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/toplevel.c:850
ijl_toplevel_eval_in at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/toplevel.c:965
eval at ./boot.jl:368 [inlined]
eval_user_input at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:151
repl_backend_loop at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:247
start_repl_backend at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:232
#run_repl#47 at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:369
run_repl at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:355
jfptr_run_repl_65104.clone_1 at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
#967 at ./client.jl:419
jfptr_YY.967_33139.clone_1 at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
jl_apply at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/julia.h:1843 [inlined]
jl_f__call_latest at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/builtins.c:774
#invokelatest#2 at ./essentials.jl:729 [inlined]
invokelatest at ./essentials.jl:726 [inlined]
run_main_repl at ./client.jl:404
exec_options at ./client.jl:318
_start at ./client.jl:522
jfptr__start_38041.clone_1 at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2377 [inlined]
ijl_apply_generic at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/gf.c:2559
jl_apply at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/julia.h:1843 [inlined]
true_main at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/jlapi.c:575
jl_repl_entrypoint at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/src/jlapi.c:719
main at /cache/build/default-amdci4-2/julialang/julia-release-1-dot-8/cli/loader_exe.c:59
__libc_start_call_main at /lib64/libc.so.6 (unknown line)
__libc_start_main at /lib64/libc.so.6 (unknown line)
unknown function (ip: 0x401098)
Allocations: 114092184 (Pool: 113988457; Big: 103727); GC: 80
@devmotion I thought we got rid of the Fix1 earlier by passing another (const) argument?
Yeah, that's not part of LogDensityProblems, only in my simpler example. I only remembered that Fix1/Fix2 are not needed in Enzyme when Julia had already segfaulted :smile:
OK, completely without LogDensityProblems:
(jl_6gvLdX) pkg> add Enzyme#main EnzymeCore#main Distributions LogDensityProblems LogDensityProblemsAD
julia> using Distributions, DynamicPPL, LogDensityProblems, LogDensityProblemsAD, Enzyme, LinearAlgebra
julia> @model function linear_regression(x)
           α ~ Normal(0, 1)
           β ~ Normal(0, 1)
           μ = α .+ β * x
           y ~ MvNormal(μ, I)
       end
linear_regression (generic function with 2 methods)
julia> x = randn(10);
julia> model = linear_regression(x);
julia> cmodel = model | (y = model(),);
julia> logjoint(cmodel, SimpleVarInfo((α = 0.0, β = 1.0)))
-27.305144853104405
julia> f = ADgradient(:Enzyme, DynamicPPL.LogDensityFunction(cmodel, SimpleVarInfo((α = 0.0, β = 1.0))));
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0])
(-27.305144853104405, [9.942031055458013, 12.437679194765497])
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0, 1.0]) # same problem as above
(-27.305144853104405, [-9.942031055458013, -45.993444104062156])
julia> logp(model, α, β) = logjoint(model, SimpleVarInfo((; α, β)))
logp (generic function with 1 method)
julia> autodiff(Enzyme.Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, 9.942031055458013, 12.437679194765499),)
julia> autodiff(Enzyme.Reverse, logp, cmodel, Active(0.0), Active(1.0)) # same problem, with `autodiff` and w/o LogDensityProblems
((nothing, -9.942031055458013, -45.993444104062156),)
julia> using ForwardDiff
julia> ForwardDiff.gradient(x -> logp(cmodel, x[1], x[2]), [0.0, 1.0])
2-element Vector{Float64}:
9.942031055458015
12.437679194765499
julia> autodiff(Enzyme.ReverseWithPrimal, logp, cmodel, Active(0.0), Active(1.0))
rep: %55 = bitcast {}* %34 to { { [2 x double], double }, {} addrspace(10)* }* prev: %56 = bitcast {} addrspace(10)* %36 to { { [2 x double], double }, {} addrspace(10)* } addrspace(10)*, !enzyme_caststack !4 inst: store { { [2 x double], double }, {} addrspace(10)* } addrspace(10)* %56, { { [2 x double], double }, {} addrspace(10)* } addrspace(10)** %57, align 8
Illegal address space propagation
UNREACHABLE executed at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:406!
signal (6): Aborted
in expression starting at REPL[25]:1
__pthread_kill_implementation at /lib64/libc.so.6 (unknown line)
gsignal at /lib64/libc.so.6 (unknown line)
abort at /lib64/libc.so.6 (unknown line)
_ZN4llvm25llvm_unreachable_internalEPKcS1_j at /home/david/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/bin/../lib/julia/libLLVM-13jl.so (unknown line)
RecursivelyReplaceAddressSpace at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:406
LowerAllocAddr at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:722
CreateAugmentedPrimal at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:2648
visitCallInst at /workspace/s
...
Another fun observation: Forward-mode returns yet another set of derivatives:
julia> ... # as above
julia> logp(cmodel, 0.0, 1.0)
-20.712423408307206
julia> autodiff(Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, 11.908124667958683, -7.800343891290028),)
julia> autodiff(Reverse, logp, cmodel, Active(0.0), Active(1.0)) # problem discussed above
((nothing, -11.908124667958683, -12.56997812841224),)
julia> autodiff(Forward, logp, BatchDuplicated, cmodel, BatchDuplicated(0.0, (1.0, 0.0)), BatchDuplicated(1.0, (0.0, 1.0)))
%"'ipc68" = addrspacecast { [1 x {} addrspace(10)*] }* %14 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !132
%49 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] undef, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc68", 0, !dbg !132
%"'ipc69" = addrspacecast { [1 x {} addrspace(10)*] }* %17 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !132
%52 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] %51, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc69", 1, !dbg !132
%"'ipc30.i" = addrspacecast { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } }* %16 to { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*, !dbg !47
%52 = insertvalue [2 x { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*] %51, { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)* %"'ipc30.i", 1, !dbg !47
%"'ipc57.i" = addrspacecast { [1 x {} addrspace(10)*] }* %17 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !187
%69 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] %68, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc57.i", 1, !dbg !187
%"'ipc56.i" = addrspacecast { [1 x {} addrspace(10)*] }* %18 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !187
%69 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] undef, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc56.i", 0, !dbg !187
%"'ipc.i" = addrspacecast { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } }* %19 to { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*, !dbg !47
%54 = insertvalue [2 x { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*] undef, { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)* %"'ipc.i", 0, !dbg !47
(-20.712423408307206, (var"1" = -6.462197351743586, var"2" = -26.1706659109923))
@devmotion okay yeah, this is going to need more minimization. In particular, getting the AD'd function as small as possible, with as few dependencies as possible, will allow us to investigate what is going wrong and why.
Even simpler than the last examples? They already use the API in DynamicPPL without LogDensityProblems and are based only on input types defined in DynamicPPL.
Yes, e.g. try to remove the dependency on DynamicPPL, the custom macro, and ideally even the model.
All fundamental things for the example (`@model`, the `Model` type itself, `SimpleVarInfo`, and the whole logic for `logjoint`) are defined in DynamicPPL, so there won't be anything left if it is removed, I think :grimacing: It is already the really core condensed part of Turing.
Sorry, I mean try to remove the dependency -- not remove the code. That is, can you inline the definitions of these macros, packages, etc. so I can run this code without using these packages or understanding what they do?
We have to try to minimize the actual number of instructions generated by Julia to find what is going wrong, and if I don't understand what a syntax means or how a package works, I cannot do that effectively.
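As a starting point for that kind of minimization, the log joint of this particular model can be written out by hand in plain Julia. The sketch below is a stand-in under stated assumptions, not DynamicPPL's code: `lognormal01`, `logp_plain`, `grad_plain`, and `grad_fd` are hypothetical names, and because it bypasses DynamicPPL entirely it may well not trigger the Enzyme problem. It does give a package-free reference value and gradient to compare AD output against.

```julia
# Log density of a standard normal at z (hypothetical helper).
lognormal01(z) = -0.5 * (z^2 + log(2π))

# Hand-written log joint of the model above:
#   α ~ Normal(0, 1), β ~ Normal(0, 1), y ~ MvNormal(α .+ β * x, I)
function logp_plain(x, y, α, β)
    lp = lognormal01(α) + lognormal01(β)
    for i in eachindex(x, y)
        lp += lognormal01(y[i] - (α + β * x[i]))
    end
    return lp
end

# Analytic gradient with respect to (α, β).
function grad_plain(x, y, α, β)
    r = y .- (α .+ β .* x)   # residuals
    return (-α + sum(r), -β + sum(r .* x))
end

# Central finite differences as an independent check.
function grad_fd(x, y, α, β; h = 1e-6)
    dα = (logp_plain(x, y, α + h, β) - logp_plain(x, y, α - h, β)) / (2h)
    dβ = (logp_plain(x, y, α, β + h) - logp_plain(x, y, α, β - h)) / (2h)
    return (dα, dβ)
end

x = randn(10)
y = randn(10)
grad_plain(x, y, 0.0, 1.0)
grad_fd(x, y, 0.0, 1.0)   # should agree with grad_plain up to finite-difference error
```

Since `logp_plain` takes only numbers and vectors, it could be passed to any AD tool without pulling in the model machinery.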
On the bright side, it appears that the latest main of Enzyme proper (aka the jll) actually fixes the reverse-mode correctness issue.
Investigating forward mode things.
It involves so many internals that unwrapping is tedious, so I am doing it step by step. But it gets more and more confusing: after unwrapping one step, I get other, completely different derivatives:
julia> ... # cmodel from above
julia> logp(model, α, β) = logjoint(model, SimpleVarInfo((; α, β)))
logp (generic function with 1 method)
julia> logp(cmodel, 0.0, 1.0)
-62.628671993804964
julia> autodiff(Enzyme.Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, -16.547138600446626, -28.39510526890181),)
julia> autodiff(Enzyme.Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, 16.547138600446626, -75.80771392179598),)
julia> logp2(model, α, β) = last(DynamicPPL._evaluate!!(model, SimpleVarInfo((; α, β)), DefaultContext())).logp
logp2 (generic function with 1 method)
julia> logp2(cmodel, 0.0, 1.0)
-62.628671993804964
julia> autodiff(Enzyme.Reverse, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, -16.547138600446626, -28.395105268901798),)
julia> autodiff(Enzyme.Reverse, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, 1.0824674490095276e-15, -0.9999999999999989),)
julia> autodiff(Enzyme.Reverse, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, 0.0, -1.0),)
julia> autodiff(Enzyme.Reverse, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, 0.0, -1.0),)
At this point, ignore the Reverse mode issue, since I think that is fixed as soon as we bump the jll. Minimizing the LLVM-level ReverseWithPrimal error, however, is most helpful.
edit: changed to say ReverseWithPrimal
And it doesn't matter what it prints out, so long as it causes the LLVM-level error.
The segfault was caused by `ReverseWithPrimal`, or are you referring to the `% ...` lines in the output?
I also ran forward mode and compared `logp` and `logp2` from the comment above:
julia> autodiff(Forward, logp, BatchDuplicated, cmodel, BatchDuplicated(0.0, (1.0, 0.0)), BatchDuplicated(1.0, (0.0, 1.0)))
%"'ipc68" = addrspacecast { [1 x {} addrspace(10)*] }* %14 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !132
%49 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] undef, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc68", 0, !dbg !132
%"'ipc69" = addrspacecast { [1 x {} addrspace(10)*] }* %17 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !132
%52 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] %51, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc69", 1, !dbg !132
%"'ipc57.i" = addrspacecast { [1 x {} addrspace(10)*] }* %16 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !187
%68 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] %67, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc57.i", 1, !dbg !187
%"'ipc56.i" = addrspacecast { [1 x {} addrspace(10)*] }* %17 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !187
%68 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] undef, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc56.i", 0, !dbg !187
%"'ipc30.i" = addrspacecast { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } }* %18 to { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*, !dbg !47
%54 = insertvalue [2 x { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*] %53, { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)* %"'ipc30.i", 1, !dbg !47
%"'ipc.i" = addrspacecast { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } }* %19 to { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*, !dbg !47
%54 = insertvalue [2 x { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)*] undef, { [1 x {} addrspace(10)*], { [1 x {} addrspace(10)*] } } addrspace(11)* %"'ipc.i", 0, !dbg !47
(-11.527262398456074, (var"1" = 0.0, var"2" = -1.0))
julia> autodiff(Forward, logp2, BatchDuplicated, cmodel, BatchDuplicated(0.0, (1.0, 0.0)), BatchDuplicated(1.0, (0.0, 1.0)))
%"'ipc27.i" = addrspacecast { [1 x {} addrspace(10)*] }* %11 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !83
%36 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] %35, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc27.i", 1, !dbg !83
%"'ipc26.i" = addrspacecast { [1 x {} addrspace(10)*] }* %12 to { [1 x {} addrspace(10)*] } addrspace(11)*, !dbg !83
%36 = insertvalue [2 x { [1 x {} addrspace(10)*] } addrspace(11)*] undef, { [1 x {} addrspace(10)*] } addrspace(11)* %"'ipc26.i", 0, !dbg !83
(-11.527262398456074, (var"1" = 0.0, var"2" = -1.0))
So it seems at least they are consistent - and they match the result that the reverse-mode derivatives of `logp2` "converge" to?
No sorry, I meant to say the ReverseWithPrimal. As you can see this is a new (not yet released) feature, hence needing some more love/testing to get all the errors back to Julia-level.
> So it seems at least they are consistent - and they match the result that the reverse-mode derivatives of logp2 "converge" to?
And they agree with ForwardDiff, it seems:
julia> ForwardDiff.gradient(x -> logp(cmodel, x[1], x[2]), [0.0, 1.0])
2-element Vector{Float64}:
0.0
-1.0
julia> ForwardDiff.gradient(x -> logp2(cmodel, x[1], x[2]), [0.0, 1.0])
2-element Vector{Float64}:
0.0
-1.0
It is weird that forward mode is printing, though. That just seems to be a println that wasn't removed (which we should also separately try to minimize and find).
:slightly_smiling_face: I think the problem with ReverseWithPrimal can be found by looking at what changed when simplifying `logp` to `logp2`, since it works for `logp2`:
julia> autodiff(ReverseWithPrimal, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, 0.0, -1.0), -11.527262398456074)
julia> logp2(cmodel, 0.0, 1.0)
-11.527262398456074
Another interesting observation is that the "convergence" observed above (everything still in the same Julia session) was apparently caused by some unintended mutation of the model. So it seems running `autodiff` modified some of the values in the model even though it shouldn't?
I'll try to confirm this with a newly sampled model instance.
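One way to test a mutation hypothesis like this is to snapshot the data before a call and compare afterwards. The sketch below uses plain vectors and hypothetical names (`mutates_first_arg`, `pure_sum`, `leaky_sum`); it does not inspect a DynamicPPL model, whose fields would need their own comparison.

```julia
# Hypothetical helper: returns true if calling f(data, args...) changed `data`.
function mutates_first_arg(f, data::AbstractVector, args...)
    snapshot = copy(data)   # independent copy taken before the call
    f(data, args...)
    return snapshot != data
end

# Reads its input only.
pure_sum(y, a) = sum(y) + a

# Overwrites its input; stands in for an unintended write.
function leaky_sum(y, a)
    y .= a
    return sum(y)
end

mutates_first_arg(pure_sum, [1.0, 2.0], 0.5)    # false
mutates_first_arg(leaky_sum, [1.0, 2.0], 0.5)   # true
```

As a side observation, assuming the model as written (ten observations with unit variance), a gradient of (0, -1) and a log density of about -11.527 are what the log joint gives at α = 0, β = 1 when y equals α .+ β * x exactly, since 12 * (-0.5 * log(2π)) - 0.5 ≈ -11.527. That would be consistent with y having been overwritten, though it does not show where.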
OK, so `ReverseWithPrimal` seems to work with `logp2` consistently and shows the same mutating behaviour:
julia> x = randn(10);
julia> model = linear_regression(x);
julia> cmodel = model | (y = model(),);
julia> logp(cmodel, 0.0, 1.0)
-25.473371825268227
julia> logp2(cmodel, 0.0, 1.0)
-25.473371825268227
julia> autodiff(ReverseWithPrimal, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, -12.193412883437983, -9.213664141918544), -25.473371825268227)
julia> autodiff(ReverseWithPrimal, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, -1.1102230246251565e-16, -1.0000000000000002), -11.527262398456074)
julia> autodiff(ReverseWithPrimal, logp2, cmodel, Active(0.0), Active(1.0))
((nothing, 0.0, -1.0), -11.527262398456074)
Even more evidence that the different behaviour is caused by differences between `DynamicPPL.evaluate_threadsafe!!` (used by `logjoint` on my computer since I use multiple threads) and `DynamicPPL.evaluate_threadunsafe!!` (which calls `_evaluate!!`, the function used in the `logp2` definition above):
julia> ... # cmodel, logp, and logp2 as above
julia> # this is called by `logjoint` (on my computer), so behaviour should be the same as for `logp`
logp3(model, α, β) = last(DynamicPPL.evaluate_threadsafe!!(model, SimpleVarInfo((; α, β)), DefaultContext())).logp
julia> # `evaluate_threadunsafe!!` calls `_evaluate!!`, so behaviour should be same as for `logp2`
logp4(model, α, β) = last(DynamicPPL.evaluate_threadunsafe!!(model, SimpleVarInfo((; α, β)), DefaultContext())).logp
julia> logp(cmodel, 0.0, 1.0)
-60.156275507914806
julia> logp2(cmodel, 0.0, 1.0)
-60.156275507914806
julia> logp3(cmodel, 0.0, 1.0)
-60.156275507914806
julia> logp4(cmodel, 0.0, 1.0)
-60.156275507914806
julia> autodiff(Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, 26.83640954931467, -30.79381366952238),)
julia> autodiff(Reverse, logp, cmodel, Active(0.0), Active(1.0))
((nothing, -26.83640954931467, -68.46421254939509),)
julia> autodiff(Reverse, logp3, cmodel, Active(0.0), Active(1.0))
((nothing, 26.83640954931467, -30.79381366952238),)
julia> autodiff(Reverse, logp3, cmodel, Active(0.0), Active(1.0)) # same as logp
((nothing, -26.83640954931467, -68.46421254939509),)
julia> autodiff(Reverse, logp3, cmodel, Active(0.0), Active(1.0))
((nothing, 26.83640954931467, -30.79381366952238),)
julia> autodiff(Reverse, logp4, cmodel, Active(0.0), Active(1.0)) # same as logp2
((nothing, -26.83640954931467, -68.46421254939509),)
julia> autodiff(Reverse, logp4, cmodel, Active(0.0), Active(1.0))
((nothing, -2.220446049250313e-16, -1.0000000000000009),)
The functions are defined at https://github.com/TuringLang/DynamicPPL.jl/blob/b23acff013a9111c8ce2c89dbf5339e76234d120/src/model.jl#L550 and https://github.com/TuringLang/DynamicPPL.jl/blob/b23acff013a9111c8ce2c89dbf5339e76234d120/src/model.jl#L565
Very roughly, the main difference is that `evaluate_threadsafe!!` uses a wrapper data structure that has very basic support for threaded log density evaluation by allowing users to add `@threads for ...` in their model. The design is not good (anymore): https://github.com/TuringLang/DynamicPPL.jl/issues/429 Could this use of `threadid` or some other Threads-related thing explain the differences in the results and why ReverseWithPrimal causes a segfault?
Both the correctness and LLVM-level errors should now be fixed on latest main. Please reopen if it persists.
Great, thank you! I checked `logp` and `logp2` (which should cover both branches), and `Reverse`, `Forward`, and `ReverseWithPrimal` all computed the correct result (checked with ForwardDiff), also when calling them multiple times.
So I guess that means that in this specific example there hasn't been a problem in Turing/DynamicPPL and we don't have to change anything?
Awesome; thank you @wsmoses !
So this is quite a fun one:
One of these is correct, as can be seen by using ForwardDiff:
It's all very interesting :thinking: