EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

StackOverflowError depending on loaded packages #1390

Closed (just-walk closed this issue 3 months ago)

just-walk commented 3 months ago

I've run into an issue where the following code will produce an error, depending on which packages I have loaded.

For example, with this set of package declarations, the result is as expected. Input A:

using LinearAlgebra
using Enzyme

function summat(p::Real)
    A = [p 1-p;
         1+p p]
    return sum(A)
end

@show autodiff(ReverseWithPrimal, summat, Active, Active(3.0))

Output A:

autodiff(ReverseWithPrimal, summat, Active, Active(3.0)) = ((2.0,), 8.0)
((2.0,), 8.0)

With this set of package declarations, an error is encountered. Input B:

using FiniteDifferences
using Enzyme

function summat(p::Real)
    A = [p 1-p;
         1+p p]
    return sum(A)
end

@show autodiff(ReverseWithPrimal, summat, Active, Active(3.0))

Output B:

ERROR: LoadError: StackOverflowError:
Stacktrace:
     [1] augmented_julia_hvcat_3949_inner_1wrap
       @ /usr/share/julia/stdlib/v1.10/SparseArrays/src/sparsevector.jl:0
     [2] macro expansion
       @ ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5594 [inlined]
     [3] enzyme_call
       @ ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5272 [inlined]
     [4] AugmentedForwardThunk
       @ ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5165 [inlined]
     [5] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(hvcat), df::Nothing, primal_1::Tuple{…}, shadow_1_1::Nothing, primal_2::Float64, shadow_2_1::Base.RefValue{…}, primal_3::Float64, shadow_3_1::Base.RefValue{…}, primal_4::Float64, shadow_4_1::Base.RefValue{…}, primal_5::Float64, shadow_5_1::Base.RefValue{…})
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/MIIMf/src/rules/jitrules.jl:179
     [6] hvcat
       @ /usr/share/julia/stdlib/v1.10/SparseArrays/src/sparsevector.jl:1269 [inlined]
     [7] hvcat
       @ /usr/share/julia/stdlib/v1.10/SparseArrays/src/sparsevector.jl:0 [inlined]
--- the last 7 lines are repeated 8930 more times ---
 [62518] augmented_julia_hvcat_3949_inner_1wrap
       @ /usr/share/julia/stdlib/v1.10/SparseArrays/src/sparsevector.jl:0
 [62519] macro expansion
       @ ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5594 [inlined]
 [62520] enzyme_call
       @ ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5272 [inlined]
 [62521] AugmentedForwardThunk
       @ ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5165 [inlined]
 [62522] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(hvcat), df::Nothing, primal_1::Tuple{…}, shadow_1_1::Nothing, primal_2::Float64, shadow_2_1::Base.RefValue{…}, primal_3::Float64, shadow_3_1::Base.RefValue{…}, primal_4::Float64, shadow_4_1::Base.RefValue{…}, primal_5::Float64, shadow_5_1::Base.RefValue{…})
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/MIIMf/src/rules/jitrules.jl:179
 [62523] hvcat
       @ /usr/share/julia/stdlib/v1.10/SparseArrays/src/sparsevector.jl:1269 [inlined]
 [62524] summat
       @ ~/.julia/dev/SaturationModelEigs/examples/enzyme-mwe.jl:12 [inlined]
 [62525] diffejulia_summat_2342wrap
       @ ~/.julia/dev/SaturationModelEigs/examples/enzyme-mwe.jl:0
 [62526] macro expansion
       @ ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5594 [inlined]
 [62527] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Active{…}, ::Float64)
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5272
 [62528] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Active{…}, ::Vararg{…})
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5154
 [62529] autodiff
       @ ~/.julia/packages/Enzyme/MIIMf/src/Enzyme.jl:275 [inlined]
 [62530] autodiff(mode::ReverseMode{true, FFIABI, false}, f::typeof(summat), ::Type{Active}, args::Active{Float64})
       @ Enzyme ~/.julia/packages/Enzyme/MIIMf/src/Enzyme.jl:287
 [62531] macro expansion
       @ show.jl:1181 [inlined]
 [62532] top-level scope
       @ ~/.julia/dev/SaturationModelEigs/examples/enzyme-mwe.jl:17
 [62533] include(fname::String)
       @ Base.MainInclude ./client.jl:489
in expression starting at /home/jwalker/.julia/dev/SaturationModelEigs/examples/enzyme-mwe.jl:17
Some type information was truncated. Use `show(err)` to see complete types.

It looks like the issue is with SparseArrays.jl. If I replace the `using FiniteDifferences` line above with `using SparseArrays`, I get a similar error. What's going on? I'm not explicitly using any sparse array types, but it seems to be running into issues trying to compile them.

wsmoses commented 3 months ago

Yeah, this is an unfortunate bug that was introduced into SparseArrays in Julia 1.10 (https://github.com/JuliaSparse/SparseArrays.jl/commit/c402d09cf05492179fad2def5632e354a81f5b30). It occurs during a type-unstable array constructor (more specifically, hvcat).

As a workaround, I'd recommend allocating your array with undef and then setting the values, or using a Julia version from before 1.10, when the bug was introduced.
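A minimal sketch of the undef workaround suggested above, applied to the function from the report. The name `summat_undef` and the choice of element type are assumptions made for illustration; they do not appear in the original issue. The idea is only that the matrix literal (which lowers to an hvcat call) is replaced by an explicit allocation followed by element assignments.

```julia
# Hypothetical rewrite of `summat` from the report: allocate the matrix with
# `undef` and fill it element by element, so no `hvcat` call is made.
function summat_undef(p::Real)
    T = typeof(1 - p)            # element type; assumes every entry converts to it
    A = Matrix{T}(undef, 2, 2)   # uninitialized 2x2 matrix
    A[1, 1] = p
    A[1, 2] = 1 - p
    A[2, 1] = 1 + p
    A[2, 2] = p
    return sum(A)
end
```

With Enzyme loaded, this could be differentiated the same way as the original, e.g. autodiff(ReverseWithPrimal, summat_undef, Active, Active(3.0)). Whether this avoids the StackOverflowError on a given Enzyme and Julia version has not been checked here. The primal value is unchanged: summat_undef(3.0) returns 8.0, matching Output A.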

We should look more closely at this, but I'm closing it as a duplicate of https://github.com/EnzymeAD/Enzyme.jl/issues/1134