EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
461 stars 66 forks source link

Nested AD Errors Out #2147

Open turiya4 opened 4 days ago

turiya4 commented 4 days ago

The following code (using Julia 1.10.6 and Enzyme 0.13.17) performs simple nested AD. However, it results in an error.

using Enzyme, Lux, Random, ComponentArrays, LinearAlgebra
n = 1  # number of columns in x_batch / y_batch; also read as a global inside batch_error
x_batch = randn(2, n)
y_batch = randn(2, n)
model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2, 1, tanh)), Dense(2, 1, tanh))
rng = Random.default_rng()
Random.seed!(rng, 0)  # note: x_batch and y_batch were already drawn above, before this seeding
ps, st = Lux.setup(Xoshiro(0), model);  # setup uses its own Xoshiro(0), not `rng`
psaxes = getaxes(ComponentArray(ps))  # axes needed to rebuild a ComponentArray from a plain array
# Scalar function being differentiated: element [1] of the first value returned by the model call.
nnfunc(x, y, psarray, st) = first(model((x, y), ComponentArray(psarray, psaxes), st))[1]
psarray = getdata(ComponentArray(ps))  # underlying data array of the parameters

function batch_error(xb, yb, psarray, st)  # returns the sum over columns i of sum(dx.^2), with dx filled by the inner autodiff
    val = zeros(n)  # one slot per column; `n` is the global defined at top level
    for i = 1 : n
        dx = zeros(2)  # shadow passed alongside xb[:, i]; read back after the inner call
        Enzyme.autodiff(Enzyme.Reverse, nnfunc, Active, Duplicated(xb[:, i], dx), Duplicated(yb[:, i], zeros(2)), Duplicated(psarray, zeros(Float32, size(psarray))), Const(st))  # inner reverse-mode pass; the shadows for y and psarray are unnamed temporaries and are never read
        val[i] = sum(dx.^2)
    end
    return sum(val)
end

dpsarray = zeros(Float32, size(psarray))  # shadow passed alongside psarray in the outer call
Enzyme.autodiff(Enzyme.Reverse, batch_error, Active, Duplicated(x_batch,zeros(size(x_batch))), Duplicated(y_batch,zeros(size(y_batch))), Duplicated(psarray, dpsarray), Const(st))  # outer reverse-mode pass over batch_error (nested AD); test.jl:25 in the stack trace below
ERROR: LoadError: Enzyme compilation failed.
Current scope: 
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Pointer}" "enzymejl_parmtype"="123305751457168" "enzymejl_parmtype_ref"="1" [3 x {} addrspace(10)*] @preprocess_julia_runtime_generic_augfwd_4730_inner.1({} addrspace(10)* nocapture nofree noundef nonnull readnone "enzyme_inactive" "enzyme_type"="{[-1]:Pointer}" "enzymejl_parmtype"="123305469593088" "enzymejl_parmtype_ref"="2" %0, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="123305771723920" "enzymejl_parmtype_ref"="2" %1, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="123305771723920" "enzymejl_parmtype_ref"="2" %2) local_unnamed_addr #5 !dbg !47 {
entry:
  %3 = call {}*** @julia.get_pgcstack() #9, !noalias !48
  %current_task1.i6 = getelementptr inbounds {}**, {}*** %3, i64 -14
  %current_task1.i = bitcast {}*** %current_task1.i6 to {}**
  %ptls_field.i7 = getelementptr inbounds {}**, {}*** %3, i64 2
  %4 = bitcast {}*** %ptls_field.i7 to i64***
  %ptls_load.i89 = load i64**, i64*** %4, align 8, !tbaa !11, !noalias !48
  %5 = getelementptr inbounds i64*, i64** %ptls_load.i89, i64 2
  %safepoint.i = load i64*, i64** %5, align 8, !tbaa !15, !noalias !48
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i) #9, !dbg !51, !noalias !48
  fence syncscope("singlethread") seq_cst
  %6 = call { { {} addrspace(10)* }, { {} addrspace(10)* } } inttoptr (i64 123305430303696 to { { {} addrspace(10)* }, { {} addrspace(10)* } } ({} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*)*)({} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %1, {} addrspace(10)* nonnull %2) #9, !dbg !53
  %7 = extractvalue { { {} addrspace(10)* }, { {} addrspace(10)* } } %6, 0, !dbg !57
  %8 = extractvalue { { {} addrspace(10)* }, { {} addrspace(10)* } } %6, 1, !dbg !57
  %box.i = call noalias nonnull dereferenceable(8) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Pointer, [-1,-1,0]:Pointer, [-1,-1,0,-1]:Float@float, [-1,-1,8]:Integer, [-1,-1,9]:Integer, [-1,-1,10]:Integer, [-1,-1,11]:Integer, [-1,-1,12]:Integer, [-1,-1,13]:Integer, [-1,-1,14]:Integer, [-1,-1,15]:Integer, [-1,-1,16]:Integer, [-1,-1,17]:Integer, [-1,-1,18]:Integer, [-1,-1,19]:Integer, [-1,-1,20]:Integer, [-1,-1,21]:Integer, [-1,-1,22]:Integer, [-1,-1,23]:Integer, [-1,-1,24]:Integer, [-1,-1,25]:Integer, [-1,-1,26]:Integer, [-1,-1,27]:Integer, [-1,-1,28]:Integer, [-1,-1,29]:Integer, [-1,-1,30]:Integer, [-1,-1,31]:Integer, [-1,-1,32]:Integer, [-1,-1,33]:Integer, [-1,-1,34]:Integer, [-1,-1,35]:Integer, [-1,-1,36]:Integer, [-1,-1,37]:Integer, [-1,-1,38]:Integer, [-1,-1,39]:Integer}" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1.i, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 123305780515024 to {}*) to {} addrspace(10)*)) #10, !dbg !58
  %9 = bitcast {} addrspace(10)* %box.i to { {} addrspace(10)* } addrspace(10)*, !dbg !58
  %10 = extractvalue { {} addrspace(10)* } %7, 0, !dbg !58
  %11 = getelementptr { {} addrspace(10)* }, { {} addrspace(10)* } addrspace(10)* %9, i64 0, i32 0, !dbg !58
  store {} addrspace(10)* %10, {} addrspace(10)* addrspace(10)* %11, align 8, !dbg !58, !tbaa !32, !alias.scope !36, !noalias !60
  %box4.i = call noalias nonnull dereferenceable(8) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Pointer, [-1,-1,0]:Pointer, [-1,-1,0,-1]:Float@float, [-1,-1,8]:Integer, [-1,-1,9]:Integer, [-1,-1,10]:Integer, [-1,-1,11]:Integer, [-1,-1,12]:Integer, [-1,-1,13]:Integer, [-1,-1,14]:Integer, [-1,-1,15]:Integer, [-1,-1,16]:Integer, [-1,-1,17]:Integer, [-1,-1,18]:Integer, [-1,-1,19]:Integer, [-1,-1,20]:Integer, [-1,-1,21]:Integer, [-1,-1,22]:Integer, [-1,-1,23]:Integer, [-1,-1,24]:Integer, [-1,-1,25]:Integer, [-1,-1,26]:Integer, [-1,-1,27]:Integer, [-1,-1,28]:Integer, [-1,-1,29]:Integer, [-1,-1,30]:Integer, [-1,-1,31]:Integer, [-1,-1,32]:Integer, [-1,-1,33]:Integer, [-1,-1,34]:Integer, [-1,-1,35]:Integer, [-1,-1,36]:Integer, [-1,-1,37]:Integer, [-1,-1,38]:Integer, [-1,-1,39]:Integer}" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1.i, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 123305780515024 to {}*) to {} addrspace(10)*)) #10, !dbg !58
  %12 = bitcast {} addrspace(10)* %box4.i to { {} addrspace(10)* } addrspace(10)*, !dbg !58
  %13 = extractvalue { {} addrspace(10)* } %8, 0, !dbg !58
  %14 = getelementptr { {} addrspace(10)* }, { {} addrspace(10)* } addrspace(10)* %12, i64 0, i32 0, !dbg !58
  store {} addrspace(10)* %13, {} addrspace(10)* addrspace(10)* %14, align 8, !dbg !58, !tbaa !32, !alias.scope !36, !noalias !60
  %.fca.0.insert = insertvalue [3 x {} addrspace(10)*] poison, {} addrspace(10)* %box.i, 0, !dbg !63
  %.fca.1.insert = insertvalue [3 x {} addrspace(10)*] %.fca.0.insert, {} addrspace(10)* %box4.i, 1, !dbg !63
  %.fca.2.insert = insertvalue [3 x {} addrspace(10)*] %.fca.1.insert, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 123304558811824 to {}*) to {} addrspace(10)*), 2, !dbg !63
  ret [3 x {} addrspace(10)*] %.fca.2.insert, !dbg !63
}

Did not have return index set when differentiating function
 call  %6 = call { { {} addrspace(10)* }, { {} addrspace(10)* } } inttoptr (i64 123305430303696 to { { {} addrspace(10)* }, { {} addrspace(10)* } } ({} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*)*)({} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %1, {} addrspace(10)* nonnull %2) #9, !dbg !19
 augmentcall  %_augmented = call { i8*, { { {} addrspace(10)* }, { {} addrspace(10)* } } } %15({} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* %1, {} addrspace(10)* %"'", {} addrspace(10)* %2, {} addrspace(10)* %"'1"), !dbg !19

Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6229
 [2] enzyme_call
   @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5775
 [3] AugmentedForwardThunk
   @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5697
 [4] runtime_generic_augfwd
   @ ~/.julia/packages/Enzyme/fpA3W/src/rules/jitrules.jl:480
 [5] runtime_generic_augfwd
   @ ~/.julia/packages/Enzyme/fpA3W/src/rules/jitrules.jl:0

Stacktrace:
  [1] julia_error(msg::String, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing}, data2::Ptr{LLVM.API.LLVMOpaqueValue}, B::Ptr{LLVM.API.LLVMOpaqueBuilder})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/errors.jl:242
  [2] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing}, data2::Ptr{LLVM.API.LLVMOpaqueValue}, B::Ptr{LLVM.API.LLVMOpaqueBuilder})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/errors.jl:97
  [3] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, runtimeActivity::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/fpA3W/src/api.jl:389
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:2145
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5426
  [6] codegen
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:4196 [inlined]
  [7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6298
  [8] _thunk
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6298 [inlined]
  [9] cached_compilation
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6339 [inlined]
 [10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6452
 [11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6604
 [12] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Enzyme.Compiler.runtime_generic_augfwd), df::Nothing, primal_1::Type{…}, shadow_1_1::Nothing, primal_2::Val{…}, shadow_2_1::Nothing, primal_3::Val{…}, shadow_3_1::Nothing, primal_4::Val{…}, shadow_4_1::Nothing, primal_5::Val{…}, shadow_5_1::Nothing, primal_6::Type{…}, shadow_6_1::Nothing, primal_7::Nothing, shadow_7_1::Nothing, primal_8::Vector{…}, shadow_8_1::Vector{…}, primal_9::Vector{…}, shadow_9_1::Vector{…}, primal_10::Tuple{…}, shadow_10_1::Nothing, primal_11::Nothing, shadow_11_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/fpA3W/src/rules/jitrules.jl:465
 [13] nnfunc
    @ ~/temp/test.jl:11 [inlined]
 [14] augmented_julia_nnfunc_2777wrap
    @ ~/temp/test.jl:0
 [15] macro expansion
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6229 [inlined]
 [16] enzyme_call
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5775 [inlined]
 [17] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5697 [inlined]
 [18] autodiff_deferred
    @ ~/.julia/packages/Enzyme/fpA3W/src/Enzyme.jl:729 [inlined]
 [19] autodiff
    @ ~/.julia/packages/Enzyme/fpA3W/src/Enzyme.jl:524 [inlined]
 [20] batch_error
    @ ~/temp/test.jl:18 [inlined]
 [21] augmented_julia_batch_error_2348wrap
    @ ~/temp/test.jl:0
 [22] macro expansion
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:6229 [inlined]
 [23] enzyme_call
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5775 [inlined]
 [24] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/fpA3W/src/compiler.jl:5697 [inlined]
 [25] autodiff
    @ ~/.julia/packages/Enzyme/fpA3W/src/Enzyme.jl:396 [inlined]
 [26] autodiff(::ReverseMode{false, false, FFIABI, false, false}, ::typeof(batch_error), ::Type{Active}, ::Duplicated{Matrix{…}}, ::Duplicated{Matrix{…}}, ::Duplicated{Vector{…}}, ::Const{@NamedTuple{…}})
    @ Enzyme ~/.julia/packages/Enzyme/fpA3W/src/Enzyme.jl:524
 [27] top-level scope
    @ ~/temp/test.jl:25
 [28] include(fname::String)
    @ Base.MainInclude ./client.jl:494
 [29] top-level scope
    @ REPL[1]:1
in expression starting at /home/work/temp/test.jl:25
Some type information was truncated. Use `show(err)` to see complete types.
vchuravy commented 3 days ago

Please post the backtrace in full, you cut out important information

turiya4 commented 3 days ago

Sorry. I have updated the original post with the full error.

vchuravy commented 3 days ago

I am confused? You removed even more information?

turiya4 commented 3 days ago

Could you please let me know now? I have updated the error.

vchuravy commented 3 days ago

Thanks, this information is always important:


Did not have return index set when differentiating function
 call  %6 = call { { {} addrspace(10)* }, { {} addrspace(10)* } } inttoptr (i64 123305430303696 to { { {} addrspace(10)* }, { {} addrspace(10)* } } ({} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*)*)({} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %1, {} addrspace(10)* nonnull %2) #9, !dbg !19
 augmentcall  %_augmented = call { i8*, { { {} addrspace(10)* }, { {} addrspace(10)* } } } %15({} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 123305469593184 to {}*) to {} addrspace(10)*), {} addrspace(10)* %1, {} addrspace(10)* %"'", {} addrspace(10)* %2, {} addrspace(10)* %"'1"), !dbg !19
wsmoses commented 3 days ago

Yeah, so the issue is that there’s an `inttoptr` call (presumably a runtime function) that we didn’t restore and thus didn’t handle properly.

wsmoses commented 3 days ago

Oh no, actually this is an instance of the deferred codegen not triggering or something?

In particular I presume the inttoptr call is from the inner AD.

We need to fix this, but just for fun, what happens if you use `set_abi(Reverse, InlineABI)` for the innermost call?