SciML / ModelingToolkitNeuralNets.jl

Symbolic-Numeric Universal Differential Equations for Automating Scientific Machine Learning (SciML)
MIT License

UDE as hierarchical component #36

Closed ggkountouras closed 1 month ago

ggkountouras commented 2 months ago

Describe the bug 🐞

When adapting the Friction example to a Diode-Capacitor model, I ran into `ERROR: MethodError: Cannot convert an object of type SymbolicUtils.BasicSymbolic{Real} to an object of type Float64`.

Expected behavior

It should work like the friction example.

Minimal Reproducible Example 👇


```julia
using ModelingToolkitNeuralNets
using ModelingToolkit
import ModelingToolkit.t_nounits as t
import ModelingToolkit.D_nounits as Dt
using ModelingToolkitStandardLibrary.Electrical
using ModelingToolkitStandardLibrary.Blocks
using OrdinaryDiffEq
using Optimization
using OptimizationOptimisers: Adam
using SciMLStructures
using SciMLStructures: Tunable
using SymbolicIndexingInterface
using StableRNGs
using Lux
using Plots

function shockley_equation(v, Is, eta, Vt)
    Is*expm1(v/(eta*Vt))
end

function ShockleyDiode(;name, Is = 1.0e-15, eta = 1.0, Vt = 26.0e-3)
    @named oneport = OnePort()
    @unpack v, i = oneport
    ps = @parameters Is=Is eta=eta Vt=Vt
    eqs = [
        i ~ shockley_equation(v, Is, eta, Vt)
    ]
    extend(ODESystem(eqs, t, [], ps; name=name), oneport)
end

@named Vin = Voltage()
@named Vin_V = Sine(amplitude=1.0, frequency=5.0)
@named D = ShockleyDiode()
@named C = Capacitor(C=1.0e-4)
@named GND = Ground()

diode_capacitor_eqs = [
    connect(Vin_V.output, Vin.V)
    connect(Vin.p, D.p)
    connect(D.n, C.p)
    connect(Vin.n, C.n, GND.g)
]

@named diode_capacitor_model_true = ODESystem(diode_capacitor_eqs, t, systems=[Vin, Vin_V, D, C, GND])
sys_true = structural_simplify(diode_capacitor_model_true)

u0 = fill(0.0, length(unknowns(sys_true)))
prob_true = ODEProblem(sys_true, u0, (0, 1.0), [])
sol_ref = solve(prob_true, Rodas4(), abstol=1e-12, reltol=1e-7) # works fine

plot(sol_ref, idxs=[Vin.v, C.v], xformatter=(_...) -> "")

function diode_ude(;name)
    @named oneport = OnePort()
    @unpack v, i = oneport
    @named nn_in = RealInputArray(nin = 1)
    @named nn_out = RealOutputArray(nout = 1)

    chain = Lux.Chain(
        Lux.Dense(1 => 10, Lux.mish, use_bias = false),
        Lux.Dense(10 => 10, Lux.mish, use_bias = false),
        Lux.Dense(10 => 1, use_bias = false)
    )
    @named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))

    eqs = [
        connect(nn_in, nn.output)
        connect(nn_out, nn.input)
        v ~ nn_out.u[1]
        i ~ nn_in.u[1]
    ]
    extend(ODESystem(eqs, t, [], []; name=name, systems=[nn_in, nn_out, nn]), oneport)
end

@named Vin2 = Voltage()
@named Vin_V2 = Sine(amplitude=1.0, frequency=5.0)
@named D2 = diode_ude(name=:diode_ude)
@named C2 = Capacitor(C=1.0e-4)
@named GND2 = Ground()

diode_capacitor_eqs2 = [
    connect(Vin_V2.output, Vin2.V)
    connect(Vin2.p, D2.p)
    connect(D2.n, C2.p)
    connect(Vin2.n, C2.n, GND2.g)
]

@named diode_capacitor_model_ude = ODESystem(diode_capacitor_eqs2, t, systems=[Vin2, Vin_V2, D2, C2, GND2])
ude_sys = complete(diode_capacitor_model_ude)
sys = structural_simplify(ude_sys)

function loss(x, (prob, sol_ref, get_vars, get_refs))
    new_p = SciMLStructures.replace(Tunable(), prob.p, x)
    new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
    ts = sol_ref.t
    new_sol = solve(new_prob, Rodas4(), saveat = ts, abstol = 1e-12, reltol = 1e-7)
    loss = zero(eltype(x))
    for i in eachindex(new_sol.u)
        loss += sum(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))
    end
    if SciMLBase.successful_retcode(new_sol)
        loss
    else
        Inf
    end
end

of = OptimizationFunction{true}(loss, AutoForwardDiff())

prob = ODEProblem(sys, [0.0, 0.0], (0, 10.0), [])
get_vars = getu(sys, [sys.C2.v])
get_refs = getu(sys_true, [sys_true.C.v])
x0 = reduce(vcat, getindex.((default_values(sys),), tunable_parameters(sys)))

cb = (opt_state, loss) -> begin
    @info "step $(opt_state.iter), loss: $loss"
    return false
end

op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs))
res = solve(op, Adam(5e-3); maxiters = 10000, callback = cb) # this fails

res_p = SciMLStructures.replace(Tunable(), prob.p, res)
res_prob = remake(prob, p = res_p)
res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
initial_sol = solve(prob, Rodas4(), saveat = sol_ref.t)
```

Error & Stacktrace ⚠️


```julia
ERROR: MethodError: Cannot `convert` an object of type SymbolicUtils.BasicSymbolic{Real} to an object of type Float64

Closest candidates are:
  convert(::Type{Float64}, ::Measures.AbsoluteLength)
   @ Measures ~/.julia/packages/Measures/PKOxJ/src/length.jl:12
  convert(::Type{T}, ::DualNumbers.Dual) where T<:Union{Real, Complex}
   @ DualNumbers ~/.julia/packages/DualNumbers/5knFX/src/dual.jl:24
  convert(::Type{T}, ::VectorizationBase.LazyMulAdd{M, O, I}) where {M, O, I, T<:Number}
   @ VectorizationBase ~/.julia/packages/VectorizationBase/jSp7w/src/lazymul.jl:25
  ...

Stacktrace:
  [1] setindex!(A::Vector{Float64}, x::SymbolicUtils.BasicSymbolic{Real}, i1::Int64)
    @ Base ./array.jl:1021
  [2] macro expansion
    @ ~/.julia/packages/SymbolicUtils/0opve/src/code.jl:418 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Symbolics/PAFGz/src/build_function.jl:546 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/SymbolicUtils/0opve/src/code.jl:375 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:163 [inlined]
  [6] macro expansion
    @ ./none:0 [inlined]
  [7] generated_callfunc
    @ ./none:0 [inlined]
  [8] (::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{…})(::Vector{…}, ::Nothing, ::Vector{…}, ::Vector{…}, ::Vector{…})
    @ RuntimeGeneratedFunctions ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:150
  [9] (::ModelingToolkit.var"#f#564"{…})(du::Vector{…}, u::Nothing, p::ModelingToolkit.MTKParameters{…})
    @ ModelingToolkit ~/.julia/packages/ModelingToolkit/353ne/src/systems/nonlinear/nonlinearsystem.jl:292
 [10] NonlinearFunction
    @ ~/.julia/packages/SciMLBase/sakPO/src/scimlfunctions.jl:2297 [inlined]
 [11] #build_null_solution#50
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:701 [inlined]
 [12] build_null_solution
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:696 [inlined]
 [13] #solve_call#44
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:604 [inlined]
 [14] solve_call
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:569 [inlined]
 [15] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1072 [inlined]
 [16] solve_up
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1066 [inlined]
 [17] #solve#51
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1003 [inlined]
 [18] solve
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:993 [inlined]
 [19] _initialize_dae!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, prob::ODEProblem{…}, alg::OrdinaryDiffEq.OverrideInit{…}, isinplace::Val{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/initialize_dae.jl:158
 [20] _initialize_dae!
    @ ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/initialize_dae.jl:87 [inlined]
 [21] initialize_dae!
    @ ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/initialize_dae.jl:77 [inlined]
 [22] initialize_dae!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/initialize_dae.jl:77
 [23] __init(prob::ODEProblem{…}, alg::Rodas4{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Rational{…}, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:502
 [24] __init (repeats 4 times)
    @ ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:11 [inlined]
 [25] __solve(::ODEProblem{…}, ::Rodas4{…}; kwargs::@Kwargs{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:6
 [26] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:1 [inlined]
 [27] solve_call(_prob::ODEProblem{…}, args::Rodas4{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:612
 [28] solve_call
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:569 [inlined]
 [29] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1080 [inlined]
 [30] solve_up
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1066 [inlined]
 [31] solve(prob::ODEProblem{…}, args::Rodas4{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1003
 [32] solve
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:993 [inlined]
 [33] loss(x::Vector{…}, ::Tuple{…})
    @ Main ~/Documents/dev/julia-MTK-NN/D-C_UDEs.jl:88
 [34] (::OptimizationForwardDiffExt.var"#37#55"{…})(::Vector{…})
    @ OptimizationForwardDiffExt ~/.julia/packages/OptimizationBase/32Mb0/ext/OptimizationForwardDiffExt.jl:98
 [35] #39
    @ ~/.julia/packages/OptimizationBase/32Mb0/ext/OptimizationForwardDiffExt.jl:102 [inlined]
 [36] chunk_mode_gradient!(result::Vector{…}, f::OptimizationForwardDiffExt.var"#39#57"{…}, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:123
 [37] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:39 [inlined]
 [38] (::OptimizationForwardDiffExt.var"#38#56"{…})(::Vector{…}, ::Vector{…})
    @ OptimizationForwardDiffExt ~/.julia/packages/OptimizationBase/32Mb0/ext/OptimizationForwardDiffExt.jl:102
 [39] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [40] macro expansion
    @ ~/.julia/packages/Optimization/ucp7G/src/utils.jl:32 [inlined]
 [41] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [42] solve!(cache::OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/solve.jl:188
 [43] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/solve.jl:96
 [44] top-level scope
    @ ~/Documents/dev/julia-MTK-NN/D-C_UDEs.jl:113
Some type information was truncated. Use `show(err)` to see complete types.
```

Environment:


```julia
Status `~/Documents/dev/julia-MTK-NN/Project.toml`
  [0c46a032] DifferentialEquations v7.13.0
⌃ [b2108857] Lux v0.5.56
⌃ [961ee093] ModelingToolkit v9.19.0
⌃ [f162e290] ModelingToolkitNeuralNets v1.0.2
⌃ [16a59e39] ModelingToolkitStandardLibrary v2.7.2
⌃ [7f7a1694] Optimization v3.26.1
  [42dfb2eb] OptimizationOptimisers v0.2.1
⌃ [1dea7af3] OrdinaryDiffEq v6.84.0
⌃ [91a5bcdd] Plots v1.40.4
⌃ [53ae85a6] SciMLStructures v1.3.0
  [860ef19b] StableRNGs v1.0.2
⌃ [c3572dad] Sundials v4.24.0
⌃ [2efcf032] SymbolicIndexingInterface v0.3.22
Info Packages marked with ⌃ have new versions available and may be upgradable.
```


```julia
Status `~/Documents/dev/julia-MTK-NN/Manifest.toml`
⌃ [47edcb42] ADTypes v1.4.0
  [1520ce14] AbstractTrees v0.4.5
⌃ [7d9f7c33] Accessors v0.1.36
  [79e6a3ab] Adapt v4.0.4
  [66dad0bd] AliasTables v1.1.3
  [dce04be8] ArgCheck v2.3.0
  [ec485272] ArnoldiMethod v0.4.0
⌃ [4fba245c] ArrayInterface v7.11.0
⌃ [4c555306] ArrayLayouts v1.9.4
  [a9b6321e] Atomix v0.1.0
⌃ [aae01518] BandedMatrices v1.7.2
⌃ [e2ed5e7c] Bijections v0.1.6
  [d1d4a3ce] BitFlags v0.1.9
⌃ [62783981] BitTwiddlingConvenienceFunctions v0.1.5
⌃ [764a87c0] BoundaryValueDiffEq v5.8.0
  [fa961155] CEnum v0.5.0
⌃ [2a0fbf3d] CPUSummary v0.2.5
  [00ebfdb7] CSTParser v3.4.3
  [49dc2e85] Calculus v0.5.1
  [d360d2e6] ChainRulesCore v1.24.0
⌃ [fb6a15b2] CloseOpenIntervals v0.1.12
⌃ [944b1d66] CodecZlib v0.7.4
⌃ [35d6a980] ColorSchemes v3.25.0
  [3da002f7] ColorTypes v0.11.5
  [c3611d14] ColorVectorSpace v0.10.0
  [5ae59095] Colors v0.12.11
  [861a8166] Combinatorics v1.0.2
  [a80b9123] CommonMark v0.8.12
  [38540f10] CommonSolve v0.2.4
  [bbf7d656] CommonSubexpressions v0.3.0
⌃ [34da2185] Compat v4.15.0
⌃ [b0b7db55] ComponentArrays v0.15.13
  [b152e2b5] CompositeTypes v0.1.4
  [a33af91c] CompositionsBase v0.1.2
  [2569d6c7] ConcreteStructs v0.2.3
⌃ [f0e56b4a] ConcurrentUtilities v2.4.1
  [88cd18e8] ConsoleProgressMonitor v0.1.2
⌃ [187b0558] ConstructionBase v1.5.5
  [d38c429a] Contour v0.6.3
  [adafc99b] CpuId v0.3.1
  [a8cc5b0e] Crayons v4.1.1
  [9a962f9c] DataAPI v1.16.0
  [864edb3b] DataStructures v0.18.20
  [e2d170a0] DataValueInterfaces v1.0.0
⌃ [bcd4f6db] DelayDiffEq v5.47.3
  [8bb1440f] DelimitedFiles v1.9.1
⌃ [2b5f629d] DiffEqBase v6.151.4
⌃ [459566f4] DiffEqCallbacks v3.6.2
⌃ [77a26b50] DiffEqNoiseProcess v5.21.0
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
  [0c46a032] DifferentialEquations v7.13.0
⌃ [a0c0ee7d] DifferentiationInterface v0.5.5
  [b4f34e82] Distances v0.10.11
⌃ [31c24e10] Distributions v0.25.109
  [ffbed154] DocStringExtensions v0.9.3
  [5b8099bc] DomainSets v0.7.14
  [fa6b7ba4] DualNumbers v0.6.8
⌅ [7c1d4256] DynamicPolynomials v0.5.7
⌅ [06fc5a27] DynamicQuantities v0.13.2
  [4e289a0a] EnumX v1.0.4
⌃ [f151be2c] EnzymeCore v0.7.6
  [460bff9d] ExceptionUnwrapping v0.1.10
  [d4d017d3] ExponentialUtilities v1.26.1
  [e2ba6199] ExprTools v0.1.10
  [c87230d0] FFMPEG v0.4.1
⌃ [9d29842c] FastAlmostBandedMatrices v0.1.2
⌃ [7034ab61] FastBroadcast v0.3.3
  [9aa1b823] FastClosures v0.3.2
  [29a986be] FastLapackInterface v2.0.4
⌃ [1a297f60] FillArrays v1.11.0
⌃ [64ca27bc] FindFirstFunctions v1.2.0
⌃ [6a86dc24] FiniteDiff v2.23.1
  [53c48c17] FixedPointNumbers v0.8.5
  [1fa38f19] Format v1.3.7
  [f6369f11] ForwardDiff v0.10.36
  [069b7b12] FunctionWrappers v1.1.3
  [77dc65aa] FunctionWrappersWrappers v0.1.3
⌃ [d9f16b24] Functors v0.4.11
  [46192b85] GPUArraysCore v0.1.6
⌃ [28b8d3ca] GR v0.73.6
  [c145ed77] GenericSchur v0.5.4
  [c27321d9] Glob v1.3.1
⌃ [86223c79] Graphs v1.11.1
  [42e2da0e] Grisu v1.0.2
  [cd3eb016] HTTP v1.10.8
⌃ [3e5b6fbb] HostCPUFeatures v0.1.16
⌃ [34004b35] HypergeometricFunctions v0.3.23
  [615f187c] IfElse v0.1.1
  [d25df0c9] Inflate v0.1.5
  [8197267c] IntervalSets v0.7.10
⌃ [3587e190] InverseFunctions v0.1.14
  [92d709cd] IrrationalConstants v0.2.2
  [82899510] IteratorInterfaceExtensions v1.0.0
  [1019f520] JLFzf v0.1.7
  [692b3bcd] JLLWrappers v1.5.0
  [682c06a0] JSON v0.21.4
⌃ [98e50ef6] JuliaFormatter v1.0.56
⌃ [ccbc3e58] JumpProcesses v9.11.1
  [ef3ab10e] KLU v0.6.0
⌃ [63c18a36] KernelAbstractions v0.9.21
  [ba0b0d4f] Krylov v0.9.6
  [5be7bae1] LBFGSB v0.4.1
⌅ [929cbde3] LLVM v7.2.1
  [b964fa9f] LaTeXStrings v1.3.1
  [2ee39098] LabelledArrays v1.16.0
  [984bce1d] LambertW v0.4.6
⌃ [23fbe1c1] Latexify v0.16.3
⌃ [10f19ff3] LayoutPointers v0.1.15
⌃ [5078a376] LazyArrays v2.0.5
  [1d6d02ad] LeftChildRightSiblingTrees v0.2.0
  [2d8b4e74] LevyArea v1.0.0
⌃ [d3d80556] LineSearches v7.2.0
⌃ [7ed4a6bd] LinearSolve v2.30.1
  [2ab3a3ac] LogExpFunctions v0.3.28
  [e6f89c97] LoggingExtras v1.0.3
⌃ [bdcacae8] LoopVectorization v0.12.170
  [30fc2ffe] LossFunctions v0.11.1
⌃ [b2108857] Lux v0.5.56
⌅ [bb33d45b] LuxCore v0.1.15
⌃ [34f89e08] LuxDeviceUtils v0.1.23
⌃ [82251201] LuxLib v0.3.27
  [d8e11817] MLStyle v0.4.17
  [1914dd2f] MacroTools v0.5.13
  [d125e4d3] ManualMemory v0.1.8
⌃ [a3b82374] MatrixFactorizations v3.0.0
  [bb5d69b7] MaybeInplace v0.1.3
  [739be429] MbedTLS v1.1.9
  [442fdcdd] Measures v0.3.2
  [e1d29d7a] Missings v1.2.0
⌃ [961ee093] ModelingToolkit v9.19.0
⌃ [f162e290] ModelingToolkitNeuralNets v1.0.2
⌃ [16a59e39] ModelingToolkitStandardLibrary v2.7.2
  [46d2c3a1] MuladdMacro v0.2.4
  [102ac46a] MultivariatePolynomials v0.5.6
⌃ [d8a4904e] MutableArithmetics v1.4.5
  [d41bc354] NLSolversBase v7.8.3
  [2774e3e8] NLsolve v4.5.1
⌃ [872c559c] NNlib v0.9.18
  [77ba4419] NaNMath v1.0.2
⌃ [8913a72c] NonlinearSolve v3.13.0
⌃ [6fe1bfb0] OffsetArrays v1.14.0
  [4d8831e6] OpenSSL v1.4.3
  [429524aa] Optim v1.9.4
  [3bd65402] Optimisers v0.3.3
⌃ [7f7a1694] Optimization v3.26.1
⌃ [bca83a33] OptimizationBase v1.2.0
  [42dfb2eb] OptimizationOptimisers v0.2.1
  [bac558e1] OrderedCollections v1.6.3
⌃ [1dea7af3] OrdinaryDiffEq v6.84.0
  [90014a1f] PDMats v0.11.31
  [65ce6f38] PackageExtensionCompat v1.0.2
  [d96e819e] Parameters v0.12.3
  [69de0a69] Parsers v2.8.1
  [570af359] PartialFunctions v1.2.0
  [b98c9c47] Pipe v1.3.0
  [ccf2f8ad] PlotThemes v3.2.0
  [995b91a9] PlotUtils v1.4.1
⌃ [91a5bcdd] Plots v1.40.4
  [e409e4f3] PoissonRandom v0.4.4
⌃ [f517fe37] Polyester v0.7.14
⌃ [1d0040c9] PolyesterWeave v0.2.1
  [85a6dd25] PositiveFactorizations v0.2.4
⌃ [d236fae5] PreallocationTools v0.4.22
  [aea7be01] PrecompileTools v1.2.1
  [21216c6a] Preferences v1.4.3
  [33c8b6b6] ProgressLogging v0.1.4
⌃ [92933f4c] ProgressMeter v1.10.0
  [43287f4e] PtrArrays v1.2.0
⌃ [1fd47b50] QuadGK v2.9.4
  [74087812] Random123 v1.7.0
⌃ [e6cf234a] RandomNumbers v1.5.3
  [3cdcf5f2] RecipesBase v1.3.4
  [01d81517] RecipesPipeline v0.6.12
⌃ [731186ca] RecursiveArrayTools v3.23.1
  [f2c3362d] RecursiveFactorization v0.2.23
  [189a3867] Reexport v1.2.2
  [05181044] RelocatableFolders v1.0.1
  [ae029012] Requires v1.3.0
  [ae5879a3] ResettableStacks v1.1.1
  [79098fc4] Rmath v0.7.1
  [7e49a35a] RuntimeGeneratedFunctions v0.5.13
  [94e857df] SIMDTypes v0.1.0
⌃ [476501e8] SLEEFPirates v0.6.42
⌃ [0bca4576] SciMLBase v2.41.3
⌃ [c0aeaf25] SciMLOperators v0.3.8
⌃ [53ae85a6] SciMLStructures v1.3.0
  [6c6a2e73] Scratch v1.2.1
  [efcf1570] Setfield v1.1.1
  [992d4aef] Showoff v1.0.3
  [777ac1f9] SimpleBufferStream v1.1.0
⌃ [727e6d20] SimpleNonlinearSolve v1.10.0
  [699a6c99] SimpleTraits v0.9.4
  [ce78b400] SimpleUnPack v1.1.0
  [a2af1166] SortingAlgorithms v1.2.1
⌃ [47a9eef4] SparseDiffTools v2.19.0
⌅ [0a514795] SparseMatrixColorings v0.3.3
  [e56a9233] Sparspak v0.3.9
  [276daf66] SpecialFunctions v2.4.0
  [860ef19b] StableRNGs v1.0.2
⌅ [aedffcd0] Static v0.8.10
⌃ [0d7ed370] StaticArrayInterface v1.5.0
⌃ [90137ffa] StaticArrays v1.9.5
  [1e83bf80] StaticArraysCore v1.4.3
  [82ae8749] StatsAPI v1.7.0
  [2913bbd2] StatsBase v0.34.3
  [4c63d2b9] StatsFuns v1.3.1
⌃ [9672c7b4] SteadyStateDiffEq v2.2.0
⌃ [789caeaf] StochasticDiffEq v6.65.1
⌃ [7792a7ef] StrideArraysCore v0.5.6
⌃ [c3572dad] Sundials v4.24.0
⌃ [2efcf032] SymbolicIndexingInterface v0.3.22
⌃ [19f23fe9] SymbolicLimits v0.2.1
⌅ [d1185830] SymbolicUtils v2.0.2
⌅ [0c5d862f] Symbolics v5.30.3
  [3783bdb8] TableTraits v1.0.1
⌃ [bd369af6] Tables v1.11.1
  [62fd8b95] TensorCore v0.1.1
⌅ [8ea1fca8] TermInterface v0.4.1
  [5d786b92] TerminalLoggers v0.1.7
  [8290d209] ThreadingUtilities v0.5.2
  [a759f4b9] TimerOutputs v0.5.24
  [0796e94c] Tokenize v0.5.29
⌅ [3bb67fe8] TranscodingStreams v0.10.9
⌃ [d5829a12] TriangularSolve v0.2.0
⌃ [410a4b4d] Tricks v0.1.8
  [781d530d] TruncatedStacktraces v1.4.0
  [5c2747f8] URIs v1.5.1
  [3a884ed6] UnPack v1.0.2
  [1cfade01] UnicodeFun v0.4.1
⌃ [1986cc42] Unitful v1.20.0
⌃ [45397f5d] UnitfulLatexify v1.6.3
  [a7c27f48] Unityper v0.1.6
  [013be700] UnsafeAtomics v0.2.1
⌅ [d80eeb9a] UnsafeAtomicsLLVM v0.1.4
  [41fe7b60] Unzip v0.2.0
⌃ [3d5dd08c] VectorizationBase v0.21.68
  [19fa3120] VertexSafeGraphs v0.2.0
⌅ [d49dbf32] WeightInitializers v0.1.7
  [6e34b625] Bzip2_jll v1.0.8+1
  [83423d85] Cairo_jll v1.18.0+2
  [2702e6a9] EpollShim_jll v0.0.20230411+0
  [2e619515] Expat_jll v2.6.2+0
⌅ [b22a6f82] FFMPEG_jll v4.4.4+1
  [a3f928ae] Fontconfig_jll v2.13.96+0
  [d7e528f0] FreeType2_jll v2.13.2+0
  [559328eb] FriBidi_jll v1.0.14+0
⌃ [0656b61e] GLFW_jll v3.3.9+0
⌅ [d2c73de3] GR_jll v0.73.6+0
  [78b55507] Gettext_jll v0.21.0+0
  [7746bdde] Glib_jll v2.80.2+0
  [3b182d85] Graphite2_jll v1.3.14+0
  [2e76f6c2] HarfBuzz_jll v2.8.1+1
⌃ [1d5cc7b8] IntelOpenMP_jll v2024.1.0+0
  [aacddb02] JpegTurbo_jll v3.0.3+0
  [c1c5ebd0] LAME_jll v3.100.2+0
⌅ [88015f11] LERC_jll v3.0.0+1
⌅ [dad2f222] LLVMExtra_jll v0.0.29+0
⌃ [1d63c593] LLVMOpenMP_jll v15.0.7+0
  [dd4b983a] LZO_jll v2.10.2+0
  [81d17ec3] L_BFGS_B_jll v3.0.1+0
⌅ [e9f186c6] Libffi_jll v3.2.2+1
  [d4300ac3] Libgcrypt_jll v1.8.11+0
  [7e76a0d4] Libglvnd_jll v1.6.0+0
  [7add5ba3] Libgpg_error_jll v1.49.0+0
  [94ce4f54] Libiconv_jll v1.17.0+0
  [4b2f31a3] Libmount_jll v2.40.1+0
⌅ [89763e89] Libtiff_jll v4.5.1+1
  [38a345b3] Libuuid_jll v2.40.1+0
⌃ [856f044c] MKL_jll v2024.1.0+0
  [e7412a2a] Ogg_jll v1.3.5+1
  [458c3c95] OpenSSL_jll v3.0.14+0
  [efe28fd5] OpenSpecFun_jll v0.5.5+0
  [91d4177d] Opus_jll v1.3.2+0
  [30392449] Pixman_jll v0.43.4+0
  [c0090381] Qt6Base_jll v6.7.1+1
⌅ [f50d1b31] Rmath_jll v0.4.2+0
⌅ [fb77eaff] Sundials_jll v5.2.2+0
  [a44049a8] Vulkan_Loader_jll v1.3.243+0
  [a2964d1f] Wayland_jll v1.21.0+1
  [2381bf8a] Wayland_protocols_jll v1.31.0+0
⌃ [02c8fc9c] XML2_jll v2.12.7+0
⌃ [aed1982a] XSLT_jll v1.1.34+0
  [ffd25f8a] XZ_jll v5.4.6+0
  [f67eecfb] Xorg_libICE_jll v1.1.1+0
  [c834827a] Xorg_libSM_jll v1.2.4+0
  [4f6342f7] Xorg_libX11_jll v1.8.6+0
  [0c0b7dd1] Xorg_libXau_jll v1.0.11+0
  [935fb764] Xorg_libXcursor_jll v1.2.0+4
  [a3789734] Xorg_libXdmcp_jll v1.1.4+0
  [1082639a] Xorg_libXext_jll v1.3.6+0
  [d091e8ba] Xorg_libXfixes_jll v5.0.3+4
  [a51aa0fd] Xorg_libXi_jll v1.7.10+4
  [d1454406] Xorg_libXinerama_jll v1.1.4+4
  [ec84b674] Xorg_libXrandr_jll v1.5.2+4
  [ea2f1a96] Xorg_libXrender_jll v0.9.11+0
  [14d82f49] Xorg_libpthread_stubs_jll v0.1.1+0
⌃ [c7cfdc94] Xorg_libxcb_jll v1.15.0+0
  [cc61e674] Xorg_libxkbfile_jll v1.1.2+0
  [e920d4aa] Xorg_xcb_util_cursor_jll v0.1.4+0
  [12413925] Xorg_xcb_util_image_jll v0.4.0+1
  [2def613f] Xorg_xcb_util_jll v0.4.0+1
  [975044d2] Xorg_xcb_util_keysyms_jll v0.4.0+1
  [0d47668e] Xorg_xcb_util_renderutil_jll v0.3.9+1
  [c22f9ab0] Xorg_xcb_util_wm_jll v0.4.1+1
  [35661453] Xorg_xkbcomp_jll v1.4.6+0
  [33bec58e] Xorg_xkeyboard_config_jll v2.39.0+0
  [c5fb5394] Xorg_xtrans_jll v1.5.0+0
  [3161d3a3] Zstd_jll v1.5.6+0
  [35ca27e7] eudev_jll v3.2.9+0
⌅ [214eeab7] fzf_jll v0.43.0+0
  [1a1c6b14] gperf_jll v3.1.1+0
  [a4ae2306] libaom_jll v3.9.0+0
  [0ac62f75] libass_jll v0.15.1+0
  [2db6ffa8] libevdev_jll v1.11.0+0
  [f638f0a6] libfdk_aac_jll v2.0.2+0
  [36db933b] libinput_jll v1.18.0+0
  [b53b4c65] libpng_jll v1.6.43+1
⌃ [f27f6e37] libvorbis_jll v1.3.7+1
  [009596ad] mtdev_jll v1.1.6+0
  [1317d2d5] oneTBB_jll v2021.12.0+0
  [1270edf5] x264_jll v2021.5.5+0
  [dfaa095f] x265_jll v3.5.0+0
  [d8fb68d0] xkbcommon_jll v1.4.1+1
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8ba89e20] Distributed
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [9fa8497b] Future
  [b77e0a4c] InteractiveUtils
  [4af54fe1] LazyArtifacts
  [b27032c2] LibCURL v0.6.4
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [a63ad114] Mmap
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.10.0
  [de0858da] Printf
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [1a1011a3] SharedArrays
  [6462fe0b] Sockets
  [2f01184e] SparseArrays v1.10.0
  [10745b16] Statistics v1.10.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v1.1.1+0
  [deac9b47] LibCURL_jll v8.4.0+0
  [e37daf67] LibGit2_jll v1.6.4+0
  [29816b5a] LibSSH2_jll v1.11.0+1
  [c8ffd9c3] MbedTLS_jll v2.28.2+1
  [14a3606d] MozillaCACerts_jll v2023.1.10
  [4536629a] OpenBLAS_jll v0.3.23+4
  [05823500] OpenLibm_jll v0.8.1+2
  [efcefdf7] PCRE2_jll v10.42.0+1
  [bea87d4a] SuiteSparse_jll v7.2.1+1
  [83775a58] Zlib_jll v1.2.13+1
  [8e850b90] libblastrampoline_jll v5.8.0+1
  [8e850ede] nghttp2_jll v1.52.0+1
  [3f19e933] p7zip_jll v17.4.0+2
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated -m`
```


```julia
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 10 × Apple M1 Max
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =
```

SebastianM-C commented 2 months ago

It looks like you have quite old versions of the packages. Can you try `]up`? I also tried to reproduce the error with newer package versions and couldn't, but I'm getting instabilities when evaluating the optimization function at the initial parameter values.

julia> of(x0, (prob, sol_ref, get_vars, get_refs))
┌ Warning: At t=0.0, dt was forced below floating point epsilon 5.0e-324, and step error estimate = NaN. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase ~/.julia/packages/SciMLBase/HReyK/src/integrator_interface.jl:600
Inf

If you build a parameter vector that zeros out the neural network parameters, you can check if the instability is related to the neural network or not. In this case, I tried that and it seems that the system is unstable even with the neural network not doing anything.

While the neural network could theoretically have parameters that make the model stable, it's very hard to train it if the starting point is unstable, since that gives an infinite loss and doesn't inform the optimizer in any way. To have a more manageable problem, the starting point should have a finite loss value.
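To illustrate why an infinite loss at the starting point gives the optimizer nothing to work with, here is a minimal sketch in plain Julia. The names `toy_loss` and `central_difference` are hypothetical and not part of any SciML package, and a finite-difference gradient stands in for the ForwardDiff gradient used in the thread, purely to keep the example self-contained.

```julia
# Hypothetical toy loss: finite inside the "stable" region |x| < 1,
# and Inf outside it, mimicking a solve that fails and returns Inf.
toy_loss(x) = abs(x) < 1 ? (x - 0.5)^2 : Inf

# Central finite-difference approximation of the derivative (assumption:
# used here only as a stand-in for automatic differentiation).
central_difference(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

# Starting inside the stable region gives a usable slope.
g_stable = central_difference(toy_loss, 0.9)

# Starting in the unstable region gives Inf - Inf = NaN: no descent direction.
g_unstable = central_difference(toy_loss, 2.0)

println("gradient at stable start:   ", g_stable)
println("gradient at unstable start: ", g_unstable)
```

The practical consequence is the one stated above: pick initial parameters for which the simulation completes, so that the first loss value is finite.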

ggkountouras commented 1 month ago

I tried `]up` and now the optimization `res = solve(op, Adam(5e-3); maxiters = 100, callback = cb)` goes through (fewer iterations, since we're only interested in the shapes here).

I don't get any instabilities or warnings about them. There are 3 unrelated warnings when setting up the problems.

The system with all neural network parameters set to 0 solves fine. Obviously the result is not correct.

Now the line `res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)` hangs indefinitely. There's no warning or other output.

ChrisRackauckas commented 1 month ago

Instrument the solver and check that its stepping makes sense.
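One way to do this, sketched under assumptions, is to use the integrator interface (`init` and `step!`) instead of `solve`, and record each accepted step. The toy right-hand side `toy_rhs!` and the helper `trace_steps` are hypothetical stand-ins; in the thread's setting the problem would be `res_prob` instead.

```julia
using OrdinaryDiffEq

# Hypothetical stiff toy problem standing in for the diode-capacitor system.
function toy_rhs!(du, u, p, t)
    du[1] = -1000.0 * (u[1] - cos(t))
    return nothing
end

# Step the integrator manually and record (time, accepted step size) pairs.
# `max_steps` bounds the loop so a stalled solve cannot hang the session.
function trace_steps(prob, alg; max_steps = 10_000, kwargs...)
    integrator = init(prob, alg; kwargs...)
    t_end = prob.tspan[2]
    steps = NamedTuple{(:t, :dt), Tuple{Float64, Float64}}[]
    while integrator.t < t_end && length(steps) < max_steps
        step!(integrator)
        push!(steps, (t = integrator.t, dt = integrator.t - integrator.tprev))
    end
    return steps
end

toy_prob = ODEProblem(toy_rhs!, [0.0], (0.0, 1.0))
steps = trace_steps(toy_prob, Rodas4(); abstol = 1e-8, reltol = 1e-8)

println("number of accepted steps: ", length(steps))
println("smallest step size:       ", minimum(s.dt for s in steps))
println("final time reached:       ", steps[end].t)
```

If the step sizes collapse toward zero or the loop hits `max_steps` long before the final time, the problem is in the dynamics; if not even the first step returns, the time is going elsewhere, for example into compilation.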

SebastianM-C commented 1 month ago

I updated my packages again (maybe I had some weird local state) and I can now reproduce your issue. The issue seems to be that

res_p = SciMLStructures.replace(Tunable(), prob.p, res)

is creating a large `MTKParameters` object that may be blowing up compile time, because `res` is the whole optimization solution object rather than the parameter vector. The fix is to use `.u` to get the underlying array:

res_p = SciMLStructures.replace(Tunable(), prob.p, res.u)

@ggkountouras Let me know if this fixes your issue.
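The distinction between the solution object and its `.u` field can be seen on a small problem. This is a sketch under assumptions: `toy_objective` and the `toy_*` names are hypothetical, and `using ForwardDiff` is included on the assumption that `AutoForwardDiff()` needs the backend package loaded.

```julia
using Optimization
using OptimizationOptimisers: Adam
using ForwardDiff  # assumption: loaded so that AutoForwardDiff() has its backend

# Hypothetical toy objective standing in for the UDE loss in the thread.
toy_objective(x, p) = sum(abs2, x .- p)

toy_of = OptimizationFunction(toy_objective, AutoForwardDiff())
toy_op = OptimizationProblem(toy_of, [0.0, 0.0], [1.0, 2.0])
toy_res = solve(toy_op, Adam(0.1); maxiters = 500)

# `toy_res` is a solution object carrying the problem, return code, stats, etc.
println(typeof(toy_res) == typeof(toy_res.u))  # the two types differ

# `toy_res.u` is the plain parameter vector, which is what functions such as
# SciMLStructures.replace expect as the new tunable values.
println(typeof(toy_res.u))
println(toy_res.objective)
```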

ggkountouras commented 1 month ago

Thanks, that fixes it.

The untrained NN makes the simulation go unstable, but that's expected.