EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
439 stars 62 forks source link

`Enzyme execution failed` with `Functors.jl` #916

Closed zuhengxu closed 1 year ago

zuhengxu commented 1 year ago

Example:

using Enzyme, Functors, Optimisers

struct MyShift{T}
    a::T
end

Functors.@functor MyShift

(s::MyShift)(x) = x .+ s.a

s = MyShift(ones(2))
x = randn(2)

# `destructure` collects all the trainable parameters in a vector, and returns this along with a function to re-build a similar structure from the vector
ps, restructure = Optimisers.destructure(s)

func(ps_) = sum(abs2, restructure(ps_)(ones(2)))

θ = randn(2)
∇θ = zeros(2)
Enzyme.API.runtimeActivity!(true)
Enzyme.autodiff(Enzyme.ReverseWithPrimal, func, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))

error message:

ERROR: Enzyme execution failed.
Enzyme: Not yet implemented augmented forward for jl_eqtable_get
Stacktrace:
 [1] get
   @ ./iddict.jl:102
 [2] in
   @ ./iddict.jl:189
 [3] haskey
   @ ./abstractdict.jl:17
 [4] CachedWalk
   @ ~/.julia/packages/Functors/8Lxin/src/walks.jl:129
 [5] recurse
   @ ~/.julia/packages/Functors/8Lxin/src/maps.jl:6
 [6] #52
   @ ~/.julia/packages/Optimisers/1x8gl/src/destructure.jl:115
 [7] map
   @ ./tuple.jl:318
 [8] map
   @ ./namedtuple.jl:219
 [9] map
   @ ./namedtuple.jl:0

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:2924
  [2] macro expansion
    @ ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9552 [inlined]
  [3] enzyme_call
    @ ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9247 [inlined]
  [4] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9221 [inlined]
  [5] runtime_generic_augfwd(activity::Val{(false, true, true, true, false)}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true, true)}, RT::Val{NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Any, Any, Any}}}, f::typeof(map), df::Nothing, primal_1::Optimisers.var"#52#53"{Functors.var"#recurse#25"{Functors.CachedWalk{Functors.ExcludeWalk{Optimisers._Trainable_biwalk, Optimisers.var"#50#51"{Vector{Float64}}, typeof(Optimisers.isnumeric)}, Functors.NoKeyword}}}, shadow_1_1::Optimisers.var"#52#53"{Functors.var"#recurse#25"{Functors.CachedWalk{Functors.ExcludeWalk{Optimisers._Trainable_biwalk, Optimisers.var"#50#51"{Vector{Float64}}, typeof(Optimisers.isnumeric)}, Functors.NoKeyword}}}, primal_2::NamedTuple{(:a,), Tuple{Vector{Float64}}}, shadow_2_1::NamedTuple{(:a,), Tuple{Vector{Float64}}}, primal_3::NamedTuple{(:a,), Tuple{Vector{Float64}}}, shadow_3_1::NamedTuple{(:a,), Tuple{Vector{Float64}}}, primal_4::NamedTuple{(:a,), Tuple{Int64}}, shadow_4_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:1293
  [6] macro expansion
    @ ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9552 [inlined]
  [7] enzyme_call
    @ ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9247 [inlined]
  [8] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9221 [inlined]
  [9] runtime_generic_augfwd(activity::Val{(false, true, true, false)}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true)}, RT::Val{NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Any, Any, Any}}}, f::Optimisers._Trainable_biwalk, df::Nothing, primal_1::Functors.var"#recurse#25"{Functors.CachedWalk{Functors.ExcludeWalk{Optimisers._Trainable_biwalk, Optimisers.var"#50#51"{Vector{Float64}}, typeof(Optimisers.isnumeric)}, Functors.NoKeyword}}, shadow_1_1::Functors.var"#recurse#25"{Functors.CachedWalk{Functors.ExcludeWalk{Optimisers._Trainable_biwalk, Optimisers.var"#50#51"{Vector{Float64}}, typeof(Optimisers.isnumeric)}, Functors.NoKeyword}}, primal_2::MyShift{Vector{Float64}}, shadow_2_1::MyShift{Vector{Float64}}, primal_3::NamedTuple{(:a,), Tuple{Int64}}, shadow_3_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:1293
 [10] macro expansion
    @ ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9552 [inlined]
 [11] enzyme_call
    @ ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9247 [inlined]
 [12] (::Enzyme.Compiler.AugmentedForwardThunk{Const{Optimisers.Restructure{MyShift{Vector{Float64}}, NamedTuple{(:a,), Tuple{Int64}}}}, Duplicated{MyShift}, Tuple{Duplicated{Vector{Float64}}}, Val{1}, Val{true}(), NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8"), Symbol("9")), NTuple{9, Any}}})(fn::Const{Optimisers.Restructure{MyShift{Vector{Float64}}, NamedTuple{(:a,), Tuple{Int64}}}}, args::Duplicated{Vector{Float64}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9221
 [13] runtime_generic_augfwd(activity::Val{(false, true)}, width::Val{1}, ModifiedBetween::Val{(true, true)}, RT::Val{NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Any, Any, Any}}}, f::Optimisers.Restructure{MyShift{Vector{Float64}}, NamedTuple{(:a,), Tuple{Int64}}}, df::Nothing, primal_1::Vector{Float64}, shadow_1_1::Vector{Float64})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:1293
 [14] macro expansion
    @ ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9552 [inlined]
 [15] enzyme_call
    @ ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9247 [inlined]
 [16] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/gS4lp/src/compiler.jl:9221 [inlined]
 [17] autodiff(#unused#::EnzymeCore.ReverseMode{true}, f::Const{typeof(func)}, #unused#::Type{Active}, args::Duplicated{Vector{Float64}})
    @ Enzyme ~/.julia/packages/Enzyme/gS4lp/src/Enzyme.jl:188
 [18] autodiff(::EnzymeCore.ReverseMode{true}, ::typeof(func), ::Type, ::Duplicated{Vector{Float64}})
    @ Enzyme ~/.julia/packages/Enzyme/gS4lp/src/Enzyme.jl:214
 [19] top-level scope
    @ ~/Research/NormalizingFlows.jl/test/test.jl:25

Looks like Enzyme might have trouble dealing with closures? I'm not certain, but this could be weakly related to #700.

System info:

Julia version: v"1.9.1"
os: Ubuntu 22.04.2 LTS
wsmoses commented 1 year ago

The issue is not the struct but the (type unstable?) dictionary lookup. See the backtrace for more information.

In any case, duplicate of https://github.com/EnzymeAD/Enzyme.jl/issues/416

torfjelde commented 1 year ago

Even if we remove the caching by overloading some of the Optimisers.jl functionality, it still fails, though for seemingly different reasons:

julia> using Enzyme, Functors, Optimisers

julia> function Optimisers._flatten(x)
         Optimisers.isnumeric(x) && return vcat(_vec(x)), 0, length(x)  # trivial case
         arrays = AbstractVector[]
         len = Ref(0)
         off = fmap(x; exclude = Optimisers.isnumeric, walk = Optimisers._TrainableStructWalk(), cache = nothing) do y
           push!(arrays, Optimisers._vec(y))
           o = len[]
           len[] = o + length(y)
           o
         end
         isempty(arrays) && return Bool[], off, 0
         reduce(vcat, arrays), off, len[]
       end

julia> function Optimisers._rebuild(x, off, flat::AbstractVector, len = length(flat); walk = Optimisers._Trainable_biwalk(), kw...)
         len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
         fmap(x, off; exclude = Optimisers.isnumeric, walk, cache = nothing, kw...) do y, o
             Optimisers._getat(y, o, flat)
         end
       end

julia> struct MyShift{T}
           a::T
       end

julia> Functors.@functor MyShift

julia> (s::MyShift)(x) = x .+ s.a

julia> s = MyShift(ones(2))
MyShift{Vector{Float64}}([1.0, 1.0])

julia> x = randn(2)
2-element Vector{Float64}:
 -0.19679963521399926
  1.295242450296637

julia> # `destructure` collects all the trainable parameters in a vector, and returns this along with a function to re-build a similar structure from the vector
       ps, restructure = Optimisers.destructure(s)
([1.0, 1.0], Restructure(MyShift, ..., 2))

julia> func(ps_) = sum(abs2, restructure(ps_)(ones(2)))
func (generic function with 1 method)

julia> θ = randn(2)
2-element Vector{Float64}:
  0.55035203569321
 -0.49582453100445745

julia> ∇θ = zeros(2)
2-element Vector{Float64}:
 0.0
 0.0

julia> func(θ)
2.657784338114955

julia> Enzyme.API.runtimeActivity!(true)

julia> Enzyme.autodiff(Enzyme.ReverseWithPrimal, func, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
warning: didn't implement memmove, using memcpy as fallback which can result in errors
((nothing,), 2.657784338114955)
Setup

```julia
(jl_M6SMPY) pkg> st --manifest
Status `/tmp/jl_M6SMPY/Manifest.toml`
  [79e6a3ab] Adapt v3.6.2
  [fa961155] CEnum v0.4.2
  [d360d2e6] ChainRulesCore v1.16.0
  [34da2185] Compat v4.7.0
  [7da242da] Enzyme v0.11.4
  [f151be2c] EnzymeCore v0.5.1
  [e2ba6199] ExprTools v0.1.9
  [d9f16b24] Functors v0.4.5
⌃ [61eb1bfa] GPUCompiler v0.21.0
  [692b3bcd] JLLWrappers v1.4.1
⌅ [929cbde3] LLVM v5.2.0
  [d8793406] ObjectFile v0.4.0
  [3bd65402] Optimisers v0.2.18
  [21216c6a] Preferences v1.4.0
  [189a3867] Reexport v1.2.2
  [ae029012] Requires v1.3.0
  [6c6a2e73] Scratch v1.2.0
  [53d494c1] StructIO v0.3.0
  [a759f4b9] TimerOutputs v0.5.23
⌅ [7cc45869] Enzyme_jll v0.0.74+0
⌅ [dad2f222] LLVMExtra_jll v0.0.22+0
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [b77e0a4c] InteractiveUtils
  [4af54fe1] LazyArtifacts
  [b27032c2] LibCURL v0.6.3
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.9.0
  [de0858da] Printf
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [6462fe0b] Sockets
  [2f01184e] SparseArrays
  [10745b16] Statistics v1.9.0
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v1.0.2+0
  [deac9b47] LibCURL_jll v7.84.0+0
  [29816b5a] LibSSH2_jll v1.10.2+0
  [c8ffd9c3] MbedTLS_jll v2.28.2+0
  [14a3606d] MozillaCACerts_jll v2022.10.11
  [4536629a] OpenBLAS_jll v0.3.21+4
  [bea87d4a] SuiteSparse_jll v5.10.1+6
  [83775a58] Zlib_jll v1.2.13+0
  [8e850b90] libblastrampoline_jll v5.7.0+0
  [8e850ede] nghttp2_jll v1.48.0+0
  [3f19e933] p7zip_jll v17.4.0+0
Info Packages marked with ⌃ and ⌅ have new versions available, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated -m`

julia> versioninfo()
Julia Version 1.9.0
Commit 8e630552924 (2023-05-07 11:25 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 12 × Intel(R) Core(TM) i7-10710U CPU @ 1.10GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
  Threads: 1 on 12 virtual cores
```

vchuravy commented 1 year ago

Even if we remove the caching by overloading some of the Optimisers.jl functionality, it still fails, though for seemingly different reasons:

What do you mean it still fails? The code seems to run to completion. Is the gradient wrong?

wsmoses commented 1 year ago

@torfjelde what do you mean by failed? It looks like it ran.

Edit: lol jinx on response race condition @vchuravy

torfjelde commented 1 year ago

Haha sorry, yeah the gradient is incorrect. Should be ones(2), no?

wsmoses commented 1 year ago

In reverse mode, `Duplicated` means the derivative is updated in place. Did you check the value of ∇θ after the run?

The 2.66 number is the result of the function func (since you requested ReverseWithPrimal).

sunxd3 commented 1 year ago

Yep, for a separate run

julia> θ = randn(2)
2-element Vector{Float64}:
  1.3248424332978874
 -0.4331324303513132

... # running the gradient function

julia> ∇θ
2-element Vector{Float64}:
 4.649684866595775
 1.1337351392973736

This is correct, because the gradient should be [2(θ[1]+1), 2(θ[2]+1)]

torfjelde commented 1 year ago

Aaah perfect; sorry, didn't think too much about what was going on but just gave a quick attempt at a fix. Thank you so much for the help @wsmoses and @vchuravy