EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
455 stars 63 forks source link

Suspected Garbage Collected Value Issue When Differentiating with Enzyme.jl #1154

Closed matinraayai closed 11 months ago

matinraayai commented 11 months ago

@wsmoses Following the fix for #1152, I now get a new error when I try to differentiate the code below:

using Redbird
using SparseArrays
using MATLAB
using Enzyme
########################################################
##   prepare simulation input
########################################################

"""
    main()

Build a Redbird forward-model configuration (mesh and optode grid are generated through
MATLAB calls), then run `Enzyme.autodiff` in reverse mode on the nested `forward_solver`
closure and show the gradient buffer `dprop`.
"""
function main()
    # Add the iso2mesh and mcx folders to the MATLAB path.
    mat"addpath('./matlab/iso2mesh')"
    mat"addpath('./matlab/mcx')"

    # 3x4 property table, reshaped below to 1x3x4 (same shape as `dprop`).
    prop = [0     0 1 1
    0.008 1 0 1.37
    0.016 1 0 1.37]
    prop = reshape(prop, 1, size(prop)...)

    cfg = Redbird.Structs.RBConfig()
    # `meshabox` is evaluated in MATLAB and returns three outputs (node, face, elem).
    cfg.node, cfg.face, cfg.elem = mxcall(:meshabox, 3, [0 0 0], [60 60 30], 1)
    nn = size(cfg.node, 1)
    cfg.seg = ones(size(cfg.elem, 1), 1)

    # presumably a 5x5 grid, i.e. 25 optode positions matching the 25x25 `detval` — TODO confirm
    (xi, yi) = mxcall(:meshgrid, 2, 60:20:140,20:20:100)

    # Sources use a third coordinate of 0, detectors a third coordinate of 60.
    cfg.srcpos = hcat(xi[:], yi[:], zeros(length(yi), 1))
    cfg.detpos = hcat(xi[:], yi[:], 60 * ones(length(yi), 1))
    cfg.srcdir = [0 0 1]
    cfg.detdir = [0 0 -1]

    # cfg.omega = 2 * pi * 70e6
    cfg.omega = [0]

    cfg.wavelengths = [""]

    # Function handed to Enzyme. It reads `cfg` from the enclosing scope and always
    # returns `nothing`.
    # NOTE(review): `cfg` is reassigned after this closure is defined (see the
    # `rbmeshprep` call below), so the closure captures a variable that is assigned more
    # than once — confirm this is intended.
    function forward_solver(prop::Array{Float64}, 
        ∇ϕ_i∇ϕ_j::Array{Float64}, detval::Array{Float64})
        ########################################################
        ##   Build LHS
        ########################################################

        wavelengths = cfg.wavelengths

        Amat, ∇ϕ_i∇ϕ_j = Redbird.Forward.rbfemlhs(cfg, prop, ∇ϕ_i∇ϕ_j, 1)

        (rhs, loc, bary, optode) = Redbird.Forward.rbfemrhs(cfg, prop)
        # ########################################################
        # ##   Solve for solutions at all freenodes: Afree*sol=rhs
        # ########################################################
        ϕ = Redbird.Forward.rbfemsolve(Amat, rhs, :qmr)
        # ########################################################
        # ##   Extract detector readings from the solutions
        # ########################################################

        # NOTE(review): this assignment rebinds the local name `detval`; it does not
        # write into the array the caller passed in. The same holds for `∇ϕ_i∇ϕ_j`
        # above. Confirm whether an in-place update was intended.
        (detval, goodix) = Redbird.Forward.rbfemgetdet(ϕ, cfg, loc, bary)
        return nothing
    end

    cfg = Redbird.Forward.rbmeshprep(cfg, prop)
    ∇s = Redbird.Forward.rb∇̇ϕ_i∇ϕ_j(cfg)
    ∇ϕ_i∇ϕ_j = ∇s.∇ϕ_i∇ϕ_j
    # Reverse mode accumulates into shadow buffers, so this one must start at zero.
    # `similar` would hand Enzyme uninitialized memory.
    d_∇ϕ_i∇ϕ_j = zero(∇ϕ_i∇ϕ_j)
    # ∇ϕ = ∇s.∇ϕ

    # Gradient buffer for `prop`, zero-initialized.
    dprop = zeros(Float64, 1, 3, 4)

    detval = zeros(Float64, 25, 25)

    # Shadow of `detval`, filled with ones.
    d_detval = ones(Float64, 25, 25)
    # Enzyme.API.runtimeActivity!(true)
    # forward_solver(prop, cfg, ∇ϕ_i∇ϕ_j, detval)
    Enzyme.autodiff(Reverse, forward_solver, Duplicated(prop, dprop),
    Duplicated(∇ϕ_i∇ϕ_j, d_∇ϕ_i∇ϕ_j),
    Duplicated(detval, d_detval));

    @show dprop
end

# Script entry point: run the example when the file is executed.
main()

This is the new error I get:

(v, inst, nty) = (LLVM.BitCastInst(%435 = bitcast i8 addrspace(13)* addrspace(11)* %.phi.trans.insert551.phi.trans.insert to {} addrspace(11)*, !dbg !815), LLVM.PHIInst(%.pre552 = phi i8 addrspace(13)* [ %.pre552.pre, %L1083.L1088_crit_edge ], [ %.pre527, %L917 ], [ %.pre527, %L891 ], [ %.pre527, %L909 ], !dbg !828), LLVM.PointerType({} addrspace(10)*))
ERROR: LoadError: AssertionError: value_type(v) == nty
Stacktrace:
  [1] nodecayed_phis!(mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/lCjdq/src/compiler/optimize.jl:219
  [2] optimize!
    @ ~/.julia/packages/Enzyme/lCjdq/src/compiler/optimize.jl:1148 [inlined]
  [3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/lCjdq/src/compiler.jl:4578
  [4] codegen
    @ ~/.julia/packages/Enzyme/lCjdq/src/compiler.jl:4203 [inlined]
  [5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/lCjdq/src/compiler.jl:5146
  [6] _thunk
    @ ~/.julia/packages/Enzyme/lCjdq/src/compiler.jl:5146 [inlined]
  [7] cached_compilation
    @ ~/.julia/packages/Enzyme/lCjdq/src/compiler.jl:5180 [inlined]
  [8] (::Enzyme.Compiler.var"#488#489"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{5, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/lCjdq/src/compiler.jl:5246
  [9] JuliaContext(f::Enzyme.Compiler.var"#488#489"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{5, Bool}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:47
 [10] #s1040#487
    @ ~/.julia/packages/Enzyme/lCjdq/src/compiler.jl:5198 [inlined]
 [11] var"#s1040#487"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [12] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [13] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(forward_solver)}, ::Type{Const{Nothing}}, ::Duplicated{Array{Float64, 3}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/lCjdq/src/Enzyme.jl:209
 [14] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(forward_solver)}, ::Duplicated{Array{Float64, 3}}, ::Const{Redbird.Structs.RBConfig}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/lCjdq/src/Enzyme.jl:238
 [15] autodiff
    @ ~/.julia/packages/Enzyme/lCjdq/src/Enzyme.jl:224 [inlined]
 [16] main()
    @ Main /scratch/raayaiardakani.m/Redbird.jl/example/redbird_enzyme.jl:75
 [17] top-level scope
    @ /scratch/raayaiardakani.m/Redbird.jl/example/redbird_enzyme.jl:83
in expression starting at /scratch/raayaiardakani.m/Redbird.jl/example/redbird_enzyme.jl:83

The repository in question is here with its submodules. The file in question is here. It can be run with `julia --project example/redbird_enzyme.jl` from the top-level folder.

The variables that I'm interested in are prop and detval. Every other input should remain a constant (even though I've duplicated them here).

Note that you need a working MATLAB installation to generate the cfg that is passed to the code. If you don't have one, Octave.jl may work as a drop-in replacement to run this code. The differentiated code itself doesn't call MATLAB and should be pure Julia.

The LLVM IR dumped by Enzyme.jl can be found here.

wsmoses commented 11 months ago

I think this PR (https://github.com/EnzymeAD/Enzyme.jl/pull/1155) may fix it. Going to tentatively close this, reopen it if not?

matinraayai commented 11 months ago

@wsmoses I moved away from using the Cfg Dict and an explicit Union{T, Nothing} struct, since Enzyme doesn't handle them well for now. I'll get back to this once I have basic code working without a cfg structure. Since moving away from them, my code no longer throws this error. It has other problems, which are better discussed in separate issues.

Thanks @wsmoses!