jarroyoe closed this issue 1 year ago.
cc @ChrisRackauckas
The latter warning implies you aren't using the latest main. What happens when you use the latest Enzyme?
Secondly, given AutoZygote, I'm not sure you're running with Enzyme here.
No, it's using Enzyme in one of the adjoints.
Though, last I checked, on main it'll still fail with Lux until https://github.com/EnzymeAD/Enzyme.jl/issues/645 is fixed. The missing getfield derivatives are kind of the last piece of the puzzle.
Yeah, but I presume the Lux AD they are running successfully isn't going through Enzyme, right? So the perf issue they see comes from Zygote?
No it would be the Enzyme BLAS fallback.
Hm, okay. In that case, so I can properly understand: can you make a version of this code that is just a call to Enzyme.autodiff?
It would just be the UDE part:
function ude!(du,u,p,t,q)
    knownPred = knownDynamics(u,p.predefined_params,q)
    nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
    for i in 1:length(u)
        du[i] = knownPred[i]+nnPred[i]
    end
end
with respect to u and p.
Can you give an example that initializes the internals? I'm not familiar with those packages.
using Enzyme, Random, Lux, ComponentArrays
rng = Random.default_rng()
neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = [-p[1].*x[1];
-p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t,q)
    knownPred = knownDynamics(u,p.predefined_params,q)
    nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
    for i in 1:length(u)
        du[i] = knownPred[i]+nnPred[i]
    end
end
du = ?
d_du = ?
u = ?
d_u = ?
p = ?
d_p = ?
t = ?
q = ?
Enzyme.autodiff(Reverse, ude!, Duplicated(du, d_du), Duplicated(u, d_u), Duplicated(p, d_p), Const(t), Const(q))
using Enzyme, Random, Lux, ComponentArrays
rng = Random.default_rng()
neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p) = [-p[1].*x[1];
-p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t)
    knownPred = knownDynamics(u,p.predefined_params)
    nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
    for i in 1:length(u)
        du[i] = knownPred[i]+nnPred[i]
    end
    nothing
end
du = zeros(2)               # output buffer, overwritten by ude!
d_du = ones(2)              # adjoint seed for du
u = training_data[:,1]
d_u = zeros(2)              # gradient w.r.t. u accumulates here
p = ps_dynamics
d_p = zero(ps_dynamics)     # gradient w.r.t. p accumulates here
t = 0.0
ude!(du,u,p,t)
Enzyme.autodiff(Reverse, ude!, Duplicated(du, d_du), Duplicated(u, d_u), Const(p), Const(t))
d_du .= 1                   # reset the seed in case the previous reverse pass zeroed it
Enzyme.autodiff(Reverse, ude!, Duplicated(du, d_du), Const(u), DuplicatedNoNeed(p, d_p), Const(t))
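For intuition about what the two autodiff calls above accumulate, the known-dynamics term on its own has a simple adjoint. With du = -p .* u, a reverse pass seeded with v (the role d_du plays) adds -p .* v to the u-gradient and -u .* v to the p-gradient. The sketch below is plain Julia with no Enzyme; known, vjp_u, vjp_p and fd_vjp are hypothetical helper names, and the neural-network term is deliberately left out. It checks the hand-derived vector-Jacobian products against central finite differences.

```julia
# Known-dynamics term from the thread: du = [-p[1]*u[1]; -p[2]*u[2]]
known(u, p) = [-p[1] * u[1]; -p[2] * u[2]]

# Hand-derived vector-Jacobian products (hypothetical helper names).
# J_u = -Diagonal(p) and J_p = -Diagonal(u) are symmetric, so J' * v == J * v.
vjp_u(v, u, p) = -p .* v
vjp_p(v, u, p) = -u .* v

# Central finite-difference estimate of J' * v w.r.t. u (which = :u) or p (which = :p).
function fd_vjp(v, u, p; which = :u, h = 1e-6)
    x = which === :u ? u : p
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        fp = which === :u ? known(xp, p) : known(u, xp)
        fm = which === :u ? known(xm, p) : known(u, xm)
        g[i] = sum(v .* (fp .- fm)) / (2h)
    end
    return g
end

u = [0.3, 0.7]
p = [1.5, 2.0]
v = [1.0, 0.5]   # adjoint seed, standing in for d_du

gu = vjp_u(v, u, p)   # [-1.5, -1.0]
gp = vjp_p(v, u, p)   # [-0.3, -0.35]
```

Comparing such hand-written adjoints with the d_u and d_p that Enzyme fills in is one way to tell a wrong derivative apart from a merely slow one.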
I hope this isn't tangential to the Lux specifics above, but back to the original question about slow performance with the BLAS fallback: it shows up even in the following simple setup:
using Enzyme, Random

function mymatmul(x::AbstractMatrix, w::AbstractMatrix)
    out = sum(w * x)
    return out
end;
Random.seed!(123)
bs = 4096
f = 256
h1 = 512
w = randn(h1, f) .* 0.01;
x = randn(f, bs) .* 0.01;
dw = zeros(h1, f);
# dx = zeros(f, bs);
Forward pass (0.01 sec)
julia> @time mymatmul(x, w)
0.012209 seconds (3 allocations: 16.000 MiB)
0.5842598323078428
Forward and backward pass (5 sec, over 400x slower)
julia> @time _, y = Enzyme.autodiff(ReverseWithPrimal, mymatmul, Const(x), Duplicated(w, dw))
5.066479 seconds (4.11 k allocations: 32.125 MiB)
((nothing, nothing), 0.5842598323078416)
This is with the current main branch:
[7da242da] Enzyme v0.11.0 `https://github.com/EnzymeAD/Enzyme.jl.git#main`
[f151be2c] EnzymeCore v0.2.1 `https://github.com/EnzymeAD/Enzyme.jl.git:lib/EnzymeCore#main`
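As a reference point for what a correct backward pass should put in dw: since sum(w * x) = Σ over i, j, l of w[i,j] * x[j,l], the gradient is ∂/∂w[i,j] = Σ over l of x[j,l], so every row of dw equals the row sums of x. The sketch below checks that closed form against finite differences at small, hypothetical sizes (stand-ins for 512, 256 and 4096), so it runs instantly and needs no Enzyme; fd_grad and dw_closed are illustrative names.

```julia
using Random

mymatmul(x::AbstractMatrix, w::AbstractMatrix) = sum(w * x)

Random.seed!(123)
h1, f, bs = 6, 4, 5          # small stand-ins for 512, 256, 4096
w = randn(h1, f) .* 0.01
x = randn(f, bs) .* 0.01

# Closed form: dw[i, j] = sum over l of x[j, l], identical for every row i.
rowsums = vec(sum(x; dims = 2))   # length f
dw_closed = ones(h1) * rowsums'   # h1 × f outer product

# Forward finite differences; exact up to rounding because the loss is linear in w.
function fd_grad(x, w; h = 1e-6)
    g = zeros(size(w))
    base = mymatmul(x, w)
    for idx in eachindex(w)
        wp = copy(w); wp[idx] += h
        g[idx] = (mymatmul(x, wp) - base) / h
    end
    return g
end

dw_fd = fd_grad(x, w)
@show maximum(abs.(dw_closed .- dw_fd))
```

The closed form does not depend on w at all, which makes it a cheap correctness check for whatever the fallback BLAS path returns in dw.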
@ChrisRackauckas yeah, this does not differentiate with Enzyme atm, so I presume Zygote or a different fallback is actually causing the issues.
No augmented forward pass found for ijl_box_char
declare nonnull {} addrspace(10)* @ijl_box_char(i32 zeroext) local_unnamed_addr #3
Stacktrace:
[1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
@ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:5040
[2] macro expansion
@ ./logging.jl:362 [inlined]
[3] macro expansion
@ ~/.julia/packages/GPUCompiler/anMCs/src/utils.jl:58 [inlined]
[4] inlining_policy
@ ~/git/Enzyme.jl/src/compiler/interpreter.jl:181
[5] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo,
I'll go add the function now regardless, to get this closer to working, but FYI I don't think Enzyme is the cause.
Summarizing for @jarroyoe: this code does not currently differentiate with Enzyme, so Enzyme is not the cause of the slowdown. (The warnings come from Enzyme being asked to differentiate something, but that result is presumably not used, since Enzyme would error.) Your performance question is therefore best asked on the issue tracker of a different library.
We'll work on making this differentiate, so let's keep the issue open regardless.
@jeremiedb suboptimal performance with the fallback BLAS is expected (hence the warning). Feel free to open a separate issue to track it. Resolving it requires either someone adding BLAS rules via EnzymeRules or finishing up @ZuseZ4's work on Enzyme-internal BLAS rules.
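For context on what such a BLAS rule has to compute: for C = A * B with incoming adjoint C̄ (Cbar below), the standard reverse-mode rule is Ā = C̄ Bᵀ and B̄ = Aᵀ C̄, i.e. two more matrix products that can themselves be dispatched to BLAS gemm. The sketch below is plain Julia and is not the EnzymeRules API; matmul_pullback is a hypothetical name. It verifies the rule with a dot-product test, using a central difference, which is exact up to rounding here because A * B is bilinear.

```julia
using LinearAlgebra, Random

# Reverse-mode rule for C = A * B: given Cbar, return (Abar, Bbar).
# Both adjoints are ordinary matrix products, so they run through BLAS gemm.
matmul_pullback(A, B, Cbar) = (Cbar * B', A' * Cbar)

Random.seed!(0)
A = randn(5, 3); B = randn(3, 4)
Cbar = randn(5, 4)                       # incoming adjoint of C
Abar, Bbar = matmul_pullback(A, B, Cbar)

# Dot-product test: <Cbar, dC> should equal <Abar, dA> + <Bbar, dB>,
# where dC is the directional derivative of A * B along (dA, dB).
dA = randn(5, 3); dB = randn(3, 4)
ε = 1e-4
dC = ((A + ε * dA) * (B + ε * dB) - (A - ε * dA) * (B - ε * dB)) / (2ε)
lhs = dot(Cbar, dC)
rhs = dot(Abar, dA) + dot(Bbar, dB)
@show lhs rhs
```

For mymatmul specifically, the incoming adjoint is ones(h1, bs), so the w-adjoint is ones(h1, bs) * x': a single gemm of the same size as the forward product, which is why a rule-based backward pass should cost on the order of the forward pass rather than hundreds of times more.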
@wsmoses thank you for this information. To summarize, is this issue entirely in the Enzyme.jl team's hands now, or can something be done in the DiffEqFlux.jl part of the script?
You could define an EnzymeRule for the unsupported part of the code.
The current bottleneck is getfield: https://github.com/EnzymeAD/Enzyme.jl/issues/645, https://github.com/EnzymeAD/Enzyme.jl/issues/644. In the OP's case, it tries Enzyme inside a try/catch. Enzyme emits the warning and then errors; the error is caught, and it falls back to ReverseDiff in scalar mode (to handle the mutation). It should be doing tape compilation there, but that's beside the point. The key is that what I showed is exactly what it tries to do with Enzyme, and that fails.
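The control flow described above (attempt Enzyme, catch the error, fall back to another backend) can be sketched roughly as follows. The names try_then_fallback, enzyme_vjp and reversediff_vjp are hypothetical stand-ins, not SciMLSensitivity's actual functions; the point is only that a warning printed during the first attempt does not mean that attempt's result was used.

```julia
# Hypothetical sketch of the "try the primary AD backend, fall back on error" pattern.
function try_then_fallback(primary, fallback, args...)
    try
        return primary(args...), :primary
    catch
        # The primary backend failed: anything it printed along the way is
        # already on screen, but its result is discarded and the fallback runs.
        return fallback(args...), :fallback
    end
end

# Stand-ins: the "Enzyme" path warns and then errors; the fallback succeeds.
function enzyme_vjp(v, p)
    @warn "stand-in warning emitted before the failure"
    error("stand-in differentiation error")
end
reversediff_vjp(v, p) = -p .* v

result, used = try_then_fallback(enzyme_vjp, reversediff_vjp, [1.0, 1.0], [1.5, 2.0])
# used == :fallback, result == [-1.5, -2.0]
```

This is why a slow run can still show Enzyme warnings even though the gradients actually came from the fallback.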
using Enzyme, Random, Lux, ComponentArrays
Enzyme.API.printall!(true)
rng = Random.default_rng()
neuralNetwork = Lux.Chain(Lux.Dense(2,1),Lux.Dense(1,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
# @show ps64
function ude!(u,p, neuralNetwork)
    neuralNetwork = Base.inferencebarrier(neuralNetwork.layers.layer_1)
    # y = neuralNetwork(u, p.layer_1, NamedTuple()) # st.layer_1)
    (y, _) = neuralNetwork(u, p.layer_1, NamedTuple()) # st.layer_1)
    y[1]::Float64
end
u = Float64[0,0]
d_u = Float64[0,0]
p = ps64
d_p = copy(ps64)
ude!(u,p, neuralNetwork)
Enzyme.autodiff(Reverse, ude!, Duplicated(u, d_u), Const(p), neuralNetwork)
This hits a GC error atm; investigating.
The GC part of the error should now be fixed, the type instability persists.
Okay, still needs work:
using OrdinaryDiffEq, DiffEqFlux
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Random
# using ComponentArrays, Lux, Plots, Random, StatsBase
# using DelimitedFiles, Serialization
rng = Random.default_rng()
const neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
const ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = [-p[1].*x[1];
-p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t,q)
    knownPred = knownDynamics(u,p.predefined_params,q)
    nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
    for i in 1:length(u)
        du[i] = knownPred[i]+nnPred[i]
    end
end
nn_dynamics!(du,u,p,t) = ude!(du,u,p,t,nothing)
prob_nn = ODEProblem(nn_dynamics!,training_data[:, 1], (Float64(1),Float64(size(training_data,2))), ps_dynamics)
function predict(p, X = training_data[:,1], T = 1:size(training_data,2))
    _prob = remake(prob_nn, u0 = X, tspan = (Float64(T[1]), Float64(T[end])), p = p)
    Array(solve(_prob, Rodas4P(), saveat = T,
                abstol=1e-7, reltol=1e-7
                ))
end
function loss_function(p)
    X̂ = predict(p)
    sum(abs2, training_data .- X̂)
end
pinit = ComponentVector(ps_dynamics)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
result_neuralode = Optimization.solve(optprob,
                                      ADAM(),
                                      maxiters = 10)
wmoses@beast:~/git/Enzyme.jl (mt) $ ./julia-1.9.0-rc1/bin/julia --project is692.jl
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/BxfIW/src/utils.jl:56
warning: didn't implement memmove, using memcpy as fallback which can result in errors
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/BxfIW/src/utils.jl:56
warning: didn't implement memmove, using memcpy as fallback which can result in errors
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/BxfIW/src/utils.jl:56
warning: didn't implement memmove, using memcpy as fallback which can result in errors
warning: didn't implement memmove, using memcpy as fallback which can result in errors
[203832] signal (11.1): Segmentation fault
in expression starting at /home/wmoses/git/Enzyme.jl/is692.jl:41
unknown function (ip: 0x7f18ea5f5443)
unknown function (ip: 0x7f18ea5f4117)
unknown function (ip: 0x7f18ea5ff542)
macro expansion at /home/wmoses/git/Enzyme.jl/src/compiler.jl:8646 [inlined]
enzyme_call at /home/wmoses/git/Enzyme.jl/src/compiler.jl:8338 [inlined]
CombinedAdjointThunk at /home/wmoses/git/Enzyme.jl/src/compiler.jl:8301 [inlined]
autodiff at /home/wmoses/git/Enzyme.jl/src/Enzyme.jl:205 [inlined]
autodiff at /home/wmoses/git/Enzyme.jl/src/Enzyme.jl:228 [inlined]
autodiff at /home/wmoses/git/Enzyme.jl/src/Enzyme.jl:214 [inlined]
_vecjacobian! at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/derivative_wrappers.jl:687
#vecjacobian!#28 at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/derivative_wrappers.jl:224 [inlined]
vecjacobian! at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/derivative_wrappers.jl:221 [inlined]
ODEInterpolatingAdjointSensitivityFunction at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/interpolating_adjoint.jl:135
ODEFunction at /home/wmoses/.julia/packages/SciMLBase/VdcHg/src/scimlfunctions.jl:2126 [inlined]
ode_determine_initdt at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/initdt.jl:53
auto_dt_reset! at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/integrators/integrator_interface.jl:442 [inlined]
handle_dt! at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:547
#__init#632 at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:509
__init at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:10 [inlined]
__init at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:10 [inlined]
__init at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:10 [inlined]
__init at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:10 [inlined]
__init at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:10 [inlined]
#__solve#631 at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:5 [inlined]
__solve at /home/wmoses/.julia/packages/OrdinaryDiffEq/gjQVg/src/solve.jl:1 [inlined]
#solve_call#22 at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:509
solve_call at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:479
unknown function (ip: 0x7f18a0158328)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
#solve_up#29 at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:932
solve_up at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:905 [inlined]
#solve#27 at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:842
solve at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:832
unknown function (ip: 0x7f18a0105686)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
#_adjoint_sensitivities#68 at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/sensitivity_interface.jl:433
_adjoint_sensitivities at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/sensitivity_interface.jl:390 [inlined]
#adjoint_sensitivities#67 at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/sensitivity_interface.jl:386 [inlined]
adjoint_sensitivities at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/sensitivity_interface.jl:358
unknown function (ip: 0x7f18a00e0296)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
adjoint_sensitivity_backpass at /home/wmoses/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:523
ZBack at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:211 [inlined]
kw_zpullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:237
#287 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
#2156#back at /home/wmoses/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
unknown function (ip: 0x7f18a00de402)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
Pullback at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:842 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
#287 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
#2156#back at /home/wmoses/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
unknown function (ip: 0x7f18a00db602)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
Pullback at /home/wmoses/.julia/packages/DiffEqBase/ihYDa/src/solve.jl:832 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
Pullback at /home/wmoses/git/Enzyme.jl/is692.jl:28 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
Pullback at /home/wmoses/git/Enzyme.jl/is692.jl:27 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
Pullback at /home/wmoses/git/Enzyme.jl/is692.jl:34 [inlined]
Pullback at /home/wmoses/git/Enzyme.jl/is692.jl:39 [inlined]
#287 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
#2156#back at /home/wmoses/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
Pullback at /home/wmoses/.julia/packages/SciMLBase/VdcHg/src/scimlfunctions.jl:3626 [inlined]
#287 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
#2156#back at /home/wmoses/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
unknown function (ip: 0x7f18a00d8846)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
Pullback at /home/wmoses/.julia/packages/Optimization/vFala/src/function/zygote.jl:31 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
#287 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
#2156#back at /home/wmoses/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
Pullback at /home/wmoses/.julia/packages/Optimization/vFala/src/function/zygote.jl:35 [inlined]
Pullback at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
#75 at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
unknown function (ip: 0x7f18a00d4516)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
gradient at /home/wmoses/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
#275 at /home/wmoses/.julia/packages/Optimization/vFala/src/function/zygote.jl:33
unknown function (ip: 0x7f18f0f998d6)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
jl_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/julia.h:1878 [inlined]
do_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/builtins.c:730
macro expansion at /home/wmoses/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:31 [inlined]
macro expansion at /home/wmoses/.julia/packages/Optimization/vFala/src/utils.jl:37 [inlined]
#__solve#1 at /home/wmoses/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:30
__solve at /home/wmoses/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:7 [inlined]
__solve at /home/wmoses/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:7 [inlined]
#solve#553 at /home/wmoses/.julia/packages/SciMLBase/VdcHg/src/solve.jl:86 [inlined]
solve at /home/wmoses/.julia/packages/SciMLBase/VdcHg/src/solve.jl:80
unknown function (ip: 0x7f18f0f6b0cd)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
jl_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/julia.h:1878 [inlined]
do_call at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:126
eval_value at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:226
eval_stmt_value at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:177 [inlined]
eval_body at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:624
jl_interpret_toplevel_thunk at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/interpreter.c:762
jl_toplevel_eval_flex at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/toplevel.c:912
jl_toplevel_eval_flex at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/toplevel.c:856
ijl_toplevel_eval_in at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/toplevel.c:971
eval at ./boot.jl:370 [inlined]
include_string at ./loading.jl:1858
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
_include at ./loading.jl:1918
include at ./Base.jl:457
jfptr_include_30733.clone_1 at /home/wmoses/git/Enzyme.jl/julia-1.9.0-rc1/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
exec_options at ./client.jl:307
_start at ./client.jl:522
jfptr__start_33350.clone_1 at /home/wmoses/git/Enzyme.jl/julia-1.9.0-rc1/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2731 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/gf.c:2913
jl_apply at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/julia.h:1878 [inlined]
true_main at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/jlapi.c:573
jl_repl_entrypoint at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/src/jlapi.c:717
main at /cache/build/default-amdci5-5/julialang/julia-release-1-dot-9/cli/loader_exe.c:59
unknown function (ip: 0x7f1a88c42d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x401098)
Allocations: 396911310 (Pool: 396812581; Big: 98729); GC: 630
Segmentation fault (core dumped)
The GC issue is now fixed on main, and once https://github.com/EnzymeAD/Enzyme.jl/pull/911 lands, the output is:
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/cy24l/src/utils.jl:56
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/cy24l/src/utils.jl:56
┌ Warning: The supplied DiffCache was too small and was enlarged. This incurs allocations
│ on the first call to `get_tmp`. If few calls to `get_tmp` occur and optimal performance is essential,
│ consider changing 'N'/chunk size of this DiffCache to 12.
└ @ PreallocationTools ~/.julia/packages/PreallocationTools/nhCNl/src/PreallocationTools.jl:155
(The same DiffCache warning is printed 19 more times.)
With the new BLAS handling (currently off by default, but it can be enabled with a flag), there is no longer a performance warning from Enzyme, though there is still an unrelated DiffCache warning above.
Full code below:
using OrdinaryDiffEq, DiffEqFlux
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Plots, Random, StatsBase
using DelimitedFiles, Serialization
rng = Random.default_rng()
using Enzyme
Enzyme.API.runtimeActivity!(true)
Enzyme.Compiler.bitcode_replacement!(false)
neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
ps, st = Lux.setup(rng, neuralNetwork)
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = [-p[1].*x[1];
-p[2].*x[2]];
training_data = rand(2,50)
ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t,q)
    knownPred = knownDynamics(u,p.predefined_params,q)
    nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
    for i in 1:length(u)
        du[i] = knownPred[i]+nnPred[i]
    end
end
nn_dynamics!(du,u,p,t) = ude!(du,u,p,t,nothing)
prob_nn = ODEProblem(nn_dynamics!,training_data[:, 1], (Float64(1),Float64(size(training_data,2))), ps_dynamics)
function predict(p, X = training_data[:,1], T = 1:size(training_data,2))
    _prob = remake(prob_nn, u0 = X, tspan = (Float64(T[1]), Float64(T[end])), p = p)
    Array(solve(_prob, Rodas4P(), saveat = T,
                abstol=1e-7, reltol=1e-7
                ))
end
function loss_function(p)
    X̂ = predict(p)
    sum(abs2, training_data .- X̂)
end
pinit = ComponentVector(ps_dynamics)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
result_neuralode = Optimization.solve(optprob,
                                      ADAM(),
                                      maxiters = 10)
That said, your code itself was also type-unstable in places, which would be a performance bottleneck. Consider changing it to something like:
using OrdinaryDiffEq, DiffEqFlux
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Plots, Random, StatsBase
using DelimitedFiles, Serialization
rng = Random.default_rng()
using Enzyme
# Enzyme.API.runtimeActivity!(true)
Enzyme.Compiler.bitcode_replacement!(false)
const neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,5),Lux.Dense(5,2))
ltup = Lux.setup(rng, neuralNetwork)
const ps = ltup[1]
const st = ltup[2]
ps64 = Float64.(ComponentArray(ps))
knownDynamics(x,p,q) = [-p[1].*x[1];
-p[2].*x[2]];
const training_data = rand(2,50)
const ps_dynamics = ComponentArray((predefined_params = rand(Float64, 2), model_params = ps64))
function ude!(du,u,p,t,q)
    knownPred = knownDynamics(u,p.predefined_params,q)
    nnPred = Array(neuralNetwork(u,p.model_params,st)[1])
    for i in 1:length(u)
        du[i] = knownPred[i]+nnPred[i]
    end
end
nn_dynamics!(du,u,p,t) = ude!(du,u,p,t,nothing)
const prob_nn = ODEProblem(nn_dynamics!,training_data[:, 1], (Float64(1),Float64(size(training_data,2))), ps_dynamics)
function predict(p, X = training_data[:,1], T = 1:size(training_data,2))
    _prob = remake(prob_nn, u0 = X, tspan = (Float64(T[1]), Float64(T[end])), p = p)
    Array(solve(_prob, Rodas4P(), saveat = T,
                abstol=1e-7, reltol=1e-7
                ))
end
function loss_function(p)
    X̂ = predict(p)
    sum(abs2, training_data .- X̂)
end
pinit = ComponentVector(ps_dynamics)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
result_neuralode = Optimization.solve(optprob,
                                      ADAM(),
                                      maxiters = 10)
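The main change in the revised script above is marking the globals that ude!, predict and loss_function read (neuralNetwork, st, training_data, prob_nn) as const. A non-const global can later be rebound to a value of any type, so inside a function Julia can only infer it as Any, and calls that depend on it fall back to dynamic dispatch. The toy functions below (hypothetical names, unrelated to the model) show the effect with Base.return_types; this is a generic illustration of the mechanism, not a profile of the script.

```julia
scale_nonconst = 2.0           # non-const global: may be rebound to any type later
const scale_const = 2.0        # const global: its type is fixed

f_nonconst(x) = scale_nonconst * x
f_const(x) = scale_const * x

Base.return_types(f_nonconst, (Float64,))  # [Any]
Base.return_types(f_const, (Float64,))     # [Float64]
```

Passing such values as function arguments (or closing over them) achieves the same inferability when making them const is inconvenient.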
When I run this code:
I get repetitions of the following warnings, and performance is significantly slowed down.
However, if I change my neural network architecture to include a single hidden layer:
neuralNetwork = Lux.Chain(Lux.Dense(2,10),Lux.Dense(10,2))
The warnings disappear. The warnings also disappear when my neural network has a single input:
I haven't tested the single-hidden-layer network at production scale, but the single-input case performs significantly better than the two-input case: the single-input run completes the production job in less than a day, while the two-input run takes several days to get through a single iteration of the for loop and runs out of 250 GB of RAM.
Besides changing my neural networks to a single hidden layer (not ideal), how can this issue be fixed?