EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
450 stars 64 forks source link

Enzyme using fallback BLAS replacements in deep neural network #692

Closed jarroyoe closed 1 year ago

jarroyoe commented 1 year ago

When I run this code:

using OrdinaryDiffEq, DiffEqFlux
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Plots, Random, StatsBase
using DelimitedFiles, Serialization
rng = Random.default_rng()

neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = [-p[1].*x[1];
                    -p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t,q)
        knownPred = knownDynamics(u,p.predefined_params,q)
        nnPred = Array(neuralNetwork(u,p.model_params,st)[1])

        for i in 1:length(u)
            du[i] = knownPred[i]+nnPred[i]
        end
end
nn_dynamics!(du,u,p,t) = ude!(du,u,p,t,nothing)
prob_nn = ODEProblem(nn_dynamics!,training_data[:, 1], (Float64(1),Float64(size(training_data,2))), ps_dynamics)
function predict(p, X = training_data[:,1], T = 1:size(training_data,2))
        _prob = remake(prob_nn, u0 = X, tspan = (Float64(T[1]), Float64(T[end])), p = p)
        Array(solve(_prob, Rodas4P(), saveat = T,
                abstol=1e-7, reltol=1e-7
                ))
end

function loss_function(p)
        X̂ = predict(p)
        sum(abs2, training_data .- X̂)
end
pinit = ComponentVector(ps_dynamics)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
result_neuralode = Optimization.solve(optprob,
                                           ADAM(),
                                           maxiters = 10)

I get repetitions of the following warnings, and performance is significantly slowed down.

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Z5kZC/src/utils.jl:35
warning: didn't implement memmove, using memcpy as fallback which can result in errors
warning: Linking two modules of different target triples: 'bcloader' is 'x86_64-unknown-linux-gnu' whereas 'text' is 'x86_64-redhat-linux'

However, if I change my neural network architecture to include a single hidden layer: neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,2))

The warnings disappear. The warnings also disappear when my neural network has a single input:

neuralNetwork = Lux.Chain(Lux.Dense(1,10),Lux.Dense(10,5),Lux.Dense(5,1))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = -p[1].*x[1]
training_data = rand(1,50)

I haven't tested having a single layer neural network on production scale, but the case with a single input has a significantly better performance than the case with two inputs (the single input completes the production in less than a day, the double input takes several days to go through a single iteration of the for loop and runs out of 250GB of RAM).

Besides changing my neural networks to a single hidden layer (not ideal), how can this issue be fixed?

wsmoses commented 1 year ago

cc @ChrisRackauckas

The latter warning implies you aren't using the latest main. What happens when you use the latest Enzyme?

Secondly, since you are using AutoZygote, I'm not sure you're running with Enzyme here.

ChrisRackauckas commented 1 year ago

No, it's using Enzyme in one of the adjoints.

Though last I checked on main it'll still fail with Lux until https://github.com/EnzymeAD/Enzyme.jl/issues/645 is fixed. The getfield missing derivatives are kind of the last piece of the puzzle.

wsmoses commented 1 year ago

Yeah but I presume the Lux AD they are doing successfully isn't going through Enzyme, right? Thus the perf issue they see is Zygote?

ChrisRackauckas commented 1 year ago

No it would be the Enzyme BLAS fallback.

wsmoses commented 1 year ago

Hm okay, in that case -- so I can properly understand, can you make a version of this code that just is a call to Enzyme.autodiff ?

ChrisRackauckas commented 1 year ago

It would just be the UDE part:

function ude!(du,u,p,t,q)
        knownPred = knownDynamics(u,p.predefined_params,q)
        nnPred = Array(neuralNetwork(u,p.model_params,st)[1])

        for i in 1:length(u)
            du[i] = knownPred[i]+nnPred[i]
        end
end

w.r.t. u,p,

wsmoses commented 1 year ago

Can you give an example that initializes the internals? Am not familiar with those packages.

using Enzyme, Random, Lux, ComponentArrays

rng = Random.default_rng()

neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = [-p[1].*x[1];
                    -p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t,q)
        knownPred = knownDynamics(u,p.predefined_params,q)
        nnPred = Array(neuralNetwork(u,p.model_params,st)[1])

        for i in 1:length(u)
            du[i] = knownPred[i]+nnPred[i]
        end
end

du = ?
d_du = ?

u = ?
d_u = ?
p = ?
d_p = ?
t = ?
q = ?
Enzyme.autodiff(Reverse, ude!, Duplicated(du, d_du), Duplicated(u, d_u), Const(p), Const(d_p), Const(t), Const(q))
ChrisRackauckas commented 1 year ago
using Enzyme, Random, Lux, ComponentArrays

rng = Random.default_rng()

neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p) = [-p[1].*x[1];
                    -p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t)
    knownPred = knownDynamics(u,p.predefined_params)
    nnPred = Array(neuralNetwork(u,p.model_params,st)[1])

    for i in 1:length(u)
        du[i] = knownPred[i]+nnPred[i]
    end
    nothing
end

du = training_data[:,1]
d_du = training_data[:,1]

u = training_data[:,1]
d_u = training_data[:,1]
p = ps_dynamics
d_p = copy(ps_dynamics)
t = 0.0

ude!(du,u,p,t)
Enzyme.autodiff(Reverse, ude!, Duplicated(du, d_du), Duplicated(u, d_u), Const(p), Const(t))
Enzyme.autodiff(Reverse, ude!, Duplicated(du, d_du), Const(u), DuplicatedNoNeed(p, d_p), Const(t))
jeremiedb commented 1 year ago

Hope this is not tangential to the specifics of Lux above, but back to the original question regarding slow performance associated with the BLAS fallback: this is something that is experienced in the following simple setup:

"""
    mymatmul(x::AbstractMatrix, w::AbstractMatrix)

Return the sum of all entries of the matrix product `w * x`.

Note the argument order: `x` is passed first but is the right-hand factor of
the product.
"""
function mymatmul(x::AbstractMatrix, w::AbstractMatrix)
    product = w * x
    return sum(product)
end;

seed!(123)
bs = 4096
f = 256
h1 = 512
w = randn(h1, f) .* 0.01;
x = randn(f, bs) .* 0.01;
dw = zeros(h1, f);
# dx = zeros(f, bs);

Forward pass (0.01 sec)

julia> @time mymatmul(x, w)
  0.012209 seconds (3 allocations: 16.000 MiB)
0.5842598323078428

Forward-backward (5 secs, over 100X slower)

julia> @time _, y = Enzyme.autodiff(ReverseWithPrimal, mymatmul, Const(x), Duplicated(w, dw))
  5.066479 seconds (4.11 k allocations: 32.125 MiB)
((nothing, nothing), 0.5842598323078416)

This is from current main branch version:

  [7da242da] Enzyme v0.11.0 `https://github.com/EnzymeAD/Enzyme.jl.git#main`
  [f151be2c] EnzymeCore v0.2.1 `https://github.com/EnzymeAD/Enzyme.jl.git:lib/EnzymeCore#main`
wsmoses commented 1 year ago

@ChrisRackauckas yeah this does not differentiate with Enzyme atm so I'm presuming this is actually Zygote or a different fallback causing the issues.

No augmented forward pass found for ijl_box_char
declare nonnull {} addrspace(10)* @ijl_box_char(i32 zeroext) local_unnamed_addr #3

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
    @ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:5040
  [2] macro expansion
    @ ./logging.jl:362 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/GPUCompiler/anMCs/src/utils.jl:58 [inlined]
  [4] inlining_policy
    @ ~/git/Enzyme.jl/src/compiler/interpreter.jl:181
  [5] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, 

Regardless, I'll go add the function now to get this closer to working -- but FYI, I don't think Enzyme is the cause.

wsmoses commented 1 year ago

Summarizing for @jarroyoe: this code does not currently differentiate with Enzyme, so Enzyme is not the cause of the performance problem (the warnings come from Enzyme being asked to differentiate something, though presumably the result is not used, since Enzyme would error). Thus your performance question is best asked on the issue tracker of a different library.

We'll work on making this differentiate so let's keep the issue open, regardless.

wsmoses commented 1 year ago

@jeremiedb suboptimal performance on fallback blas is expected (hence the warning about it). Feel free to open a separate issue to track. Resolving it requires either someone adding the BLAS rules in EnzymeRules or finishing up @ZuseZ4's work on Enzyme-internal blas rules.

jarroyoe commented 1 year ago

@wsmoses thank you for this information. To summarize, is this issue entirely in the Enzyme.jl team's hands now, or can something be done through the DiffEqFlux.jl part of the script?

wsmoses commented 1 year ago

You could define an EnzymeRule for the unsupported part of the code.

ChrisRackauckas commented 1 year ago

The current bottleneck is getfield: https://github.com/EnzymeAD/Enzyme.jl/issues/645, https://github.com/EnzymeAD/Enzyme.jl/issues/644. What's going on in the OP's case is that it will do a try/catch on Enzyme, which throws the warning, but then fails (errors), which is caught, and then it falls back to using ReverseDiff in scalar mode (to handle the mutation), and it should be doing tape compilation, but that's all a bit beside the point. The key is that what I showed is what it tries to do with Enzyme and fails.

wsmoses commented 1 year ago
using Enzyme, Random, Lux, ComponentArrays

Enzyme.API.printall!(true)

rng = Random.default_rng()

neuralNetwork = Lux.Chain(Lux.Dense(2,1),Lux.Dense(1,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
# @show ps64

function ude!(u,p, neuralNetwork)
    neuralNetwork = Base.inferencebarrier(neuralNetwork.layers.layer_1)
    # y = neuralNetwork(u, p.layer_1, NamedTuple()) # st.layer_1)
    (y, _) = neuralNetwork(u, p.layer_1, NamedTuple()) # st.layer_1)

    y[1]::Float64
end

u = Float64[0,0]
d_u = Float64[0,0]
p = ps64
d_p = copy(ps64)

ude!(u,p, neuralNetwork)
Enzyme.autodiff(Reverse, ude!, Duplicated(u, d_u), Const(p), neuralNetwork)

this hits a GC error atm, investigating.

wsmoses commented 1 year ago

The GC part of the error should now be fixed; the type instability persists.

wsmoses commented 1 year ago

Okay, still needs work:


using OrdinaryDiffEq, DiffEqFlux
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Random
# using ComponentArrays, Lux, Plots, Random, StatsBase
# using DelimitedFiles, Serialization
rng = Random.default_rng()

const neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
const ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = [-p[1].*x[1];
                    -p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t,q)
        knownPred = knownDynamics(u,p.predefined_params,q)
        nnPred = Array(neuralNetwork(u,p.model_params,st)[1])

        for i in 1:length(u)
            du[i] = knownPred[i]+nnPred[i]
        end
end
nn_dynamics!(du,u,p,t) = ude!(du,u,p,t,nothing)
prob_nn = ODEProblem(nn_dynamics!,training_data[:, 1], (Float64(1),Float64(size(training_data,2))), ps_dynamics)
function predict(p, X = training_data[:,1], T = 1:size(training_data,2))
        _prob = remake(prob_nn, u0 = X, tspan = (Float64(T[1]), Float64(T[end])), p = p)
        Array(solve(_prob, Rodas4P(), saveat = T,
                abstol=1e-7, reltol=1e-7
                ))
end

function loss_function(p)
        X̂ = predict(p)
        sum(abs2, training_data .- X̂)
end
pinit = ComponentVector(ps_dynamics)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
result_neuralode = Optimization.solve(optprob,
                                           ADAM(),
                                           maxiters = 10)
wmoses@beast:~/git/Enzyme.jl (mt) $ ./julia-1.9.0-rc1/bin/julia --project is692.jl
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/BxfIW/src/utils.jl:56
warning: didn't implement memmove, using memcpy as fallback which can result in errors
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/BxfIW/src/utils.jl:56
warning: didn't implement memmove, using memcpy as fallback which can result in errors
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/BxfIW/src/utils.jl:56
warning: didn't implement memmove, using memcpy as fallback which can result in errors
warning: didn't implement memmove, using memcpy as fallback which can result in errors

[203832] signal (11.1): Segmentation fault
in expression starting at /home/wmoses/git/Enzyme.jl/is692.jl:41
unknown function (ip: 0x7f18ea5f5443)
unknown function (ip: 0x7f18ea5f4117)
unknown function (ip: 0x7f18ea5ff542)
macro expansion at /home/wmoses/git/Enzyme.jl/src/compiler.jl:8646 [inlined]
enzyme_call at /home/wmoses/git/Enzyme.jl/src/compiler.jl:8338 [inlined]
CombinedAdjointThunk at /home/wmoses/git/Enzyme.jl/src/compiler.jl:8301 [inlined]
autodiff at /home/wmoses/git/Enzyme.jl/src/Enzyme.jl:205 [inlined]
autodiff at /home/wmoses/git/Enzyme.jl/src/Enzyme.jl:228 [inlined]
autodiff at /home/wmoses/git/Enzyme.jl/src/Enzyme.jl:214 [inlined]
_vecjacobian! at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/derivative_wrappers.jl:687
#vecjacobian!#28 at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/derivative_wrappers.jl:224 [inlined]
vecjacobian! at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/derivative_wrappers.jl:221 [inlined]
ODEInterpolatingAdjointSensitivityFunction at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/interpolating_adjoint.jl:135
ODEFunction at /home/wmoses/.julia/packages/SciMLBase/VdcHg/src/scimlfunctions.jl:2126 [inlined]
ode_determine_initdt at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/initdt.jl:53
auto_dt_reset! at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/integrators/integrator_interface.jl:442 [inlined]
handle_dt! at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:547
#__init#632 at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:509
__init at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:10 [inlined]
__init at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:10 [inlined]
__init at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:10 [inlined]
__init at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:10 [inlined]
__init at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:10 [inlined]
#__solve#631 at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:5 [inlined]
__solve at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:1 [inlined]
#solve_call#22 at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:509
solve_call at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:479
unknown function (ip: 0x7f18a0158328)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
#solve_up#29 at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:932
solve_up at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:905 [inlined]
#solve#27 at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:842
solve at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:832
unknown function (ip: 0x7f18a0105686)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
#_adjoint_sensitivities#68 at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/sensitivity_interface.jl:433
_adjoint_sensitivities at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/sensitivity_interface.jl:390 [inlined]
#adjoint_sensitivities#67 at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/sensitivity_interface.jl:386 [inlined]
adjoint_sensitivities at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/sensitivity_interface.jl:358
unknown function (ip: 0x7f18a00e0296)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
adjoint_sensitivity_backpass at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:523
ZBack at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:211 [inlined]
kw_zpullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:237
#287 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
#2156#back at /home/wmoses/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
unknown function (ip: 0x7f18a00de402)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
Pullback at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:842 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
#287 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
#2156#back at /home/wmoses/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
unknown function (ip: 0x7f18a00db602)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
Pullback at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:832 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
Pullback at /home/wmoses/git/Enzyme.jl/is692.jl:28 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
Pullback at /home/wmoses/git/Enzyme.jl/is692.jl:27 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
Pullback at /home/wmoses/git/Enzyme.jl/is692.jl:34 [inlined]
Pullback at /home/wmoses/git/Enzyme.jl/is692.jl:39 [inlined]
#287 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
#2156#back at /home/wmoses/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
Pullback at /home/wmoses/.julia/packages/SciMLBase/VdcHg/src/scimlfunctions.jl:3626 [inlined]
#287 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
#2156#back at /home/wmoses/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
unknown function (ip: 0x7f18a00d8846)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
Pullback at /home/wmoses/.julia/packages/Optimization/vFala/src/function/zygote.jl:31 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
#287 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
#2156#back at /home/wmoses/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
Pullback at /home/wmoses/.julia/packages/Optimization/vFala/src/function/zygote.jl:35 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
#75 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
unknown function (ip: 0x7f18a00d4516)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
gradient at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
#275 at /home/wmoses/.julia/packages/Optimization/vFala/src/function/zygote.jl:33
unknown function (ip: 0x7f18f0f998d6)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
jl_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/julia.h:1878 [inlined]
do_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/builtins.c:730
macro expansion at /home/wmoses/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:31 [inlined]
macro expansion at /home/wmoses/.julia/packages/Optimization/vFala/src/utils.jl:37 [inlined]
#__solve#1 at /home/wmoses/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:30
__solve at /home/wmoses/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:7 [inlined]
__solve at /home/wmoses/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:7 [inlined]
#solve#553 at /home/wmoses/.julia/packages/SciMLBase/VdcHg/src/solve.jl:86 [inlined]
solve at /home/wmoses/.julia/packages/SciMLBase/VdcHg/src/solve.jl:80
unknown function (ip: 0x7f18f0f6b0cd)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
jl_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/julia.h:1878 [inlined]
do_call at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:126
eval_value at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:226
eval_stmt_value at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:177 [inlined]
eval_body at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:624
jl_interpret_toplevel_thunk at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:762
jl_toplevel_eval_flex at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/toplevel.c:912
jl_toplevel_eval_flex at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/toplevel.c:856
ijl_toplevel_eval_in at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/toplevel.c:971
eval at ./boot.jl:370 [inlined]
include_string at ./loading.jl:1858
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
_include at ./loading.jl:1918
include at ./Base.jl:457
jfptr_include_30733.clone_1 at /home/wmoses/git/Enzyme.jl/julia-1.9.0-rc1/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
exec_options at ./client.jl:307
_start at ./client.jl:522
jfptr__start_33350.clone_1 at /home/wmoses/git/Enzyme.jl/julia-1.9.0-rc1/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
jl_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/julia.h:1878 [inlined]
true_main at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/jlapi.c:573
jl_repl_entrypoint at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/jlapi.c:717
main at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/cli/loader_exe.c:59
unknown function (ip: 0x7f1a88c42d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x401098)
Allocations: 396911310 (Pool: 396812581; Big: 98729); GC: 630
Segmentation fault (core dumped)
wsmoses commented 1 year ago

out6.txt

wsmoses commented 1 year ago

GC issue now fixed on main, and [once https://github.com/EnzymeAD/Enzyme.jl/pull/911 lands], output is now:

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/cy24l/src/utils.jl:56
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/cy24l/src/utils.jl:56
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│     on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│     consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
wsmoses commented 1 year ago

With the (currently off by default, but can be enabled with a flag) BLAS handling, there is no longer a performance warning from Enzyme (though there's still an unrelated diffcache one above).

Full code below:

using OrdinaryDiffEq, DiffEqFlux
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Plots, Random, StatsBase
using DelimitedFiles, Serialization
rng = Random.default_rng()

using Enzyme
# Turn on Enzyme's runtime-activity setting for this session.
Enzyme.API.runtimeActivity!(true)
# Turn off Enzyme's bitcode replacement.
# NOTE(review): presumably this is the flag that switches from the fallback BLAS
# replacements to the BLAS handling mentioned in the comment above; confirm.
Enzyme.Compiler.bitcode_replacement!(false)

# Chain of three dense layers with sizes 2 -> 10 -> 5 -> 2.
neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
# Lux.setup yields two values for the network, bound here as `ps` and `st`;
# `st` is later passed as the third argument when the network is called in `ude!`.
ps, st = Lux.setup(rng, neuralNetwork)
# Wrap `ps` in a ComponentArray and convert every element to Float64.
ps64 = Float64.(ComponentArray(ps))
"""
    knownDynamics(x, p, q)

Return the two-element vector `[-p[1] * x[1], -p[2] * x[2]]`.

Only the first two entries of `x` and `p` are read. `q` is accepted but not used.
"""
function knownDynamics(x, p, q)
    first_rate = -p[1] .* x[1]
    second_rate = -p[2] .* x[2]
    return [first_rate; second_rate]
end
# 2x50 matrix of random values; its first column is used below as the initial state and
# its column count sets the end of the time span and the save points.
training_data = rand(2,50)
# Parameter container with two named parts: `predefined_params` (two random Float64
# values read by `knownDynamics`) and `model_params` (the Float64 network parameters).
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
"""
    ude!(du, u, p, t, q)

Fill `du` in place with the sum of the known dynamics and the neural-network output.

For each index `i` in `1:length(u)`, `du[i]` is set to
`knownDynamics(u, p.predefined_params, q)[i]` plus the `i`-th entry of the first output
of `neuralNetwork(u, p.model_params, st)` (converted with `Array`).

Reads the globals `neuralNetwork` and `st`. `t` is not used. Returns `nothing`.
"""
function ude!(du, u, p, t, q)
    known_part = knownDynamics(u, p.predefined_params, q)
    nn_output = neuralNetwork(u, p.model_params, st)
    learned_part = Array(nn_output[1])

    for i in 1:length(u)
        du[i] = known_part[i] + learned_part[i]
    end
    return nothing
end
"""
    nn_dynamics!(du, u, p, t)

Four-argument wrapper around `ude!` that passes `nothing` for its fifth argument `q`.
"""
function nn_dynamics!(du, u, p, t)
    return ude!(du, u, p, t, nothing)
end
# ODE problem built from `nn_dynamics!`: the initial state is the first column of
# `training_data`, the time span runs from 1 to the number of data columns, and the
# parameters are `ps_dynamics`.
prob_nn = ODEProblem(
    nn_dynamics!,
    training_data[:, 1],
    (1.0, Float64(size(training_data, 2))),
    ps_dynamics,
)
"""
    predict(p, X = training_data[:, 1], T = 1:size(training_data, 2))

Solve the ODE problem for parameters `p` and return the solution as an `Array`.

`prob_nn` is rebuilt with `remake` using initial state `X`, a time span from `T[1]` to
`T[end]` (both converted to `Float64`), and parameters `p`. It is solved with
`Rodas4P()`, saving at the points `T`, with absolute and relative tolerances of `1e-7`.
"""
function predict(p, X = training_data[:, 1], T = 1:size(training_data, 2))
    t_start = Float64(T[1])
    t_stop = Float64(T[end])
    problem = remake(prob_nn; u0 = X, tspan = (t_start, t_stop), p = p)
    solution = solve(problem, Rodas4P(); saveat = T, abstol = 1e-7, reltol = 1e-7)
    return Array(solution)
end

"""
    loss_function(p)

Return the sum of squared differences between `training_data` and `predict(p)`.
"""
function loss_function(p)
    prediction = predict(p)
    residual = training_data .- prediction
    return sum(abs2, residual)
end
# Starting point for the optimisation: the combined parameter container as a
# ComponentVector.
pinit = ComponentVector(ps_dynamics)
# Select Zygote as the automatic-differentiation backend for Optimization.
adtype = Optimization.AutoZygote()
# The objective takes (x, p) but only forwards x to `loss_function`; p is ignored.
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
# Run ADAM for at most 10 iterations.
result_neuralode = Optimization.solve(optprob,
                                           ADAM(),
                                           maxiters = 10)

That said, your code itself was also type-unstable in places, which would be a performance bottleneck; consider changing it to something like:

using OrdinaryDiffEq, DiffEqFlux
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Plots, Random, StatsBase
using DelimitedFiles, Serialization
rng = Random.default_rng()

using Enzyme
# The runtime-activity call is left commented out in this variant, so it does not run.
# Enzyme.API.runtimeActivity!(true)
# Turn off Enzyme's bitcode replacement.
# NOTE(review): presumably this is the flag that switches from the fallback BLAS
# replacements to Enzyme's own BLAS handling; confirm.
Enzyme.Compiler.bitcode_replacement!(false)

# Chain of three dense layers with sizes 2 -> 10 -> 5 -> 2, declared `const` so that
# functions reading this global see a fixed binding.
const neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))

# Keep the result of Lux.setup in one value, then bind its two elements as constants
# (a destructuring assignment is avoided here so each name can be declared `const`).
ltup = Lux.setup(rng, neuralNetwork)
const ps = ltup[1]
const st = ltup[2]

# Wrap `ps` in a ComponentArray and convert every element to Float64.
ps64 = Float64.(ComponentArray(ps))

"""
    knownDynamics(x, p, q)

Return a two-element vector whose entries are `-p[1] * x[1]` and `-p[2] * x[2]`.

`q` is part of the signature but is never read.
"""
function knownDynamics(x, p, q)
    return vcat(-p[1] .* x[1], -p[2] .* x[2])
end

# 2x50 matrix of random values, declared `const`; its first column is used below as the
# initial state and its column count sets the end of the time span and the save points.
const training_data = rand(2,50)

# Constant parameter container with two named parts: `predefined_params` (two random
# Float64 values read by `knownDynamics`) and `model_params` (the Float64 network
# parameters).
const ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))

"""
    ude!(du, u, p, t, q)

Fill `du` in place with the sum of the known dynamics and the neural-network output.

For every index `i` of `u`, `du[i]` is set to
`knownDynamics(u, p.predefined_params, q)[i]` plus the `i`-th entry of the first output
of `neuralNetwork(u, p.model_params, st)` (converted with `Array`).

Reads the globals `neuralNetwork` and `st`. `t` is not used. Returns `nothing`.
"""
function ude!(du, u, p, t, q)
    knownPred = knownDynamics(u, p.predefined_params, q)
    nnPred = Array(neuralNetwork(u, p.model_params, st)[1])

    # `eachindex` follows the actual indices of `u` instead of assuming they are
    # 1:length(u).
    for i in eachindex(u)
        du[i] = knownPred[i] + nnPred[i]
    end
    # Make the in-place contract explicit: the result lives in `du`.
    return nothing
end

"""
    nn_dynamics!(du, u, p, t)

Call `ude!(du, u, p, t, nothing)`, i.e. `ude!` with its fifth argument fixed to
`nothing`, and return its result.
"""
function nn_dynamics!(du, u, p, t)
    return ude!(du, u, p, t, nothing)
end

# Constant ODE problem built from `nn_dynamics!`: initial state is the first column of
# `training_data`, the time span is (1, number of data columns) as Float64 values, and
# the parameters are `ps_dynamics`.
const prob_nn = ODEProblem(
    nn_dynamics!,
    training_data[:, 1],
    (1.0, Float64(size(training_data, 2))),
    ps_dynamics,
)

"""
    predict(p, X = training_data[:, 1], T = 1:size(training_data, 2))

Return the ODE solution for parameters `p` as an `Array`.

A copy of `prob_nn` is made with `remake`, using `X` as the initial state, the span
`(Float64(T[1]), Float64(T[end]))`, and `p` as parameters. The copy is solved with
`Rodas4P()`, saving at `T`, with `abstol` and `reltol` both set to `1e-7`.
"""
function predict(p, X = training_data[:, 1], T = 1:size(training_data, 2))
    span = (Float64(T[1]), Float64(T[end]))
    remade = remake(prob_nn; u0 = X, tspan = span, p = p)
    return Array(solve(remade, Rodas4P(); saveat = T, abstol = 1e-7, reltol = 1e-7))
end

"""
    loss_function(p)

Sum of squared element-wise differences between `training_data` and `predict(p)`.
"""
function loss_function(p)
    simulated = predict(p)
    return sum(abs2, training_data .- simulated)
end
# Starting point for the optimisation: the combined parameter container as a
# ComponentVector.
pinit = ComponentVector(ps_dynamics)
# Select Zygote as the automatic-differentiation backend for Optimization.
adtype = Optimization.AutoZygote()
# The objective takes (x, p) but only forwards x to `loss_function`; p is ignored.
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
# Run ADAM for at most 10 iterations.
result_neuralode = Optimization.solve(optprob,
                                           ADAM(),
                                           maxiters = 10)