UDE as hierarchical component #36

Closed ggkountouras closed 1 month ago

ggkountouras commented 2 months ago

Describe the bug 🐞

When adapting the Friction example to a Diode-Capacitor model, I ran into `ERROR: MethodError: Cannot convert an object of type SymbolicUtils.BasicSymbolic{Real} to an object of type Float64`.

Expected behavior

It should work like the friction example.

Minimal Reproducible Example 👇


```julia
using ModelingToolkitNeuralNets
using ModelingToolkit
import ModelingToolkit.t_nounits as t
import ModelingToolkit.D_nounits as Dt
using ModelingToolkitStandardLibrary.Electrical
using ModelingToolkitStandardLibrary.Blocks
using OrdinaryDiffEq
using Optimization
using OptimizationOptimisers: Adam
using SciMLStructures
using SciMLStructures: Tunable
using SymbolicIndexingInterface
using StableRNGs
using Lux
using Plots

function shockley_equation(v, Is, eta, Vt)
    Is*expm1(v/(eta*Vt))
end

function ShockleyDiode(;name, Is = 1.0e-15, eta = 1.0, Vt = 26.0e-3)
    @named oneport = OnePort()
    @unpack v, i = oneport
    ps = @parameters Is=Is eta=eta Vt=Vt
    eqs = [
        i ~ shockley_equation(v, Is, eta, Vt)
    ]
    extend(ODESystem(eqs, t, [], ps; name=name), oneport)
end

@named Vin = Voltage()
@named Vin_V = Sine(amplitude=1.0, frequency=5.0)
@named D = ShockleyDiode()
@named C = Capacitor(C=1.0e-4)
@named GND = Ground()

diode_capacitor_eqs = [
    connect(Vin_V.output, Vin.V)
    connect(Vin.p, D.p)
    connect(D.n, C.p)
    connect(Vin.n, C.n, GND.g)
]

@named diode_capacitor_model_true = ODESystem(diode_capacitor_eqs, t, systems=[Vin, Vin_V, D, C, GND])
sys_true = structural_simplify(diode_capacitor_model_true)

u0 = fill(0.0, length(unknowns(sys_true)))
prob_true = ODEProblem(sys_true, u0, (0, 1.0), [])
sol_ref = solve(prob_true, Rodas4(), abstol=1e-12, reltol=1e-7) # works fine
plot(sol_ref, idxs=[Vin.v, C.v], xformatter=(_...) -> "")

function diode_ude(;name)
    @named oneport = OnePort()
    @unpack v, i = oneport
    @named nn_in = RealInputArray(nin = 1)
    @named nn_out = RealOutputArray(nout = 1)

    chain = Lux.Chain(
        Lux.Dense(1 => 10, Lux.mish, use_bias = false),
        Lux.Dense(10 => 10, Lux.mish, use_bias = false),
        Lux.Dense(10 => 1, use_bias = false)
    )
    @named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))

    eqs = [
        connect(nn_in, nn.output)
        connect(nn_out, nn.input)
        v ~ nn_out.u[1]
        i ~ nn_in.u[1]
    ]
    extend(ODESystem(eqs, t, [], []; name=name, systems=[nn_in, nn_out, nn]), oneport)
end

@named Vin2 = Voltage()
@named Vin_V2 = Sine(amplitude=1.0, frequency=5.0)
@named D2 = diode_ude(name=:diode_ude)
@named C2 = Capacitor(C=1.0e-4)
@named GND2 = Ground()

diode_capacitor_eqs2 = [
    connect(Vin_V2.output, Vin2.V)
    connect(Vin2.p, D2.p)
    connect(D2.n, C2.p)
    connect(Vin2.n, C2.n, GND2.g)
]

@named diode_capacitor_model_ude = ODESystem(diode_capacitor_eqs2, t, systems=[Vin2, Vin_V2, D2, C2, GND2])
ude_sys = complete(diode_capacitor_model_ude)
sys = structural_simplify(ude_sys)

function loss(x, (prob, sol_ref, get_vars, get_refs))
    new_p = SciMLStructures.replace(Tunable(), prob.p, x)
    new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
    ts = sol_ref.t
    new_sol = solve(new_prob, Rodas4(), saveat = ts, abstol = 1e-12, reltol = 1e-7)
    loss = zero(eltype(x))
    for i in eachindex(new_sol.u)
        loss += sum(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))
    end
    if SciMLBase.successful_retcode(new_sol)
        loss
    else
        Inf
    end
end

of = OptimizationFunction{true}(loss, AutoForwardDiff())

prob = ODEProblem(sys, [0.0, 0.0], (0, 10.0), [])
get_vars = getu(sys, [sys.C2.v])
get_refs = getu(sys_true, [sys_true.C.v])
x0 = reduce(vcat, getindex.((default_values(sys),), tunable_parameters(sys)))

cb = (opt_state, loss) -> begin
    @info "step $(opt_state.iter), loss: $loss"
    return false
end

op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs))
res = solve(op, Adam(5e-3); maxiters = 10000, callback = cb) # this fails

res_p = SciMLStructures.replace(Tunable(), prob.p, res)
res_prob = remake(prob, p = res_p)
res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
initial_sol = solve(prob, Rodas4(), saveat = sol_ref.t)
```

Error & Stacktrace ⚠️


```julia
ERROR: MethodError: Cannot `convert` an object of type SymbolicUtils.BasicSymbolic{Real} to an object of type Float64

Closest candidates are:
  convert(::Type{Float64}, ::Measures.AbsoluteLength)
   @ Measures ~/.julia/packages/Measures/PKOxJ/src/length.jl:12
  convert(::Type{T}, ::DualNumbers.Dual) where T<:Union{Real, Complex}
   @ DualNumbers ~/.julia/packages/DualNumbers/5knFX/src/dual.jl:24
  convert(::Type{T}, ::VectorizationBase.LazyMulAdd{M, O, I}) where {M, O, I, T<:Number}
   @ VectorizationBase ~/.julia/packages/VectorizationBase/jSp7w/src/lazymul.jl:25
  ...

Stacktrace:
  [1] setindex!(A::Vector{Float64}, x::SymbolicUtils.BasicSymbolic{Real}, i1::Int64)
    @ Base ./array.jl:1021
  [2] macro expansion
    @ ~/.julia/packages/SymbolicUtils/0opve/src/code.jl:418 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Symbolics/PAFGz/src/build_function.jl:546 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/SymbolicUtils/0opve/src/code.jl:375 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:163 [inlined]
  [6] macro expansion
    @ ./none:0 [inlined]
  [7] generated_callfunc
    @ ./none:0 [inlined]
  [8] (::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{…})(::Vector{…}, ::Nothing, ::Vector{…}, ::Vector{…}, ::Vector{…})
    @ RuntimeGeneratedFunctions ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:150
  [9] (::ModelingToolkit.var"#f#564"{…})(du::Vector{…}, u::Nothing, p::ModelingToolkit.MTKParameters{…})
    @ ModelingToolkit ~/.julia/packages/ModelingToolkit/353ne/src/systems/nonlinear/nonlinearsystem.jl:292
 [10] NonlinearFunction
    @ ~/.julia/packages/SciMLBase/sakPO/src/scimlfunctions.jl:2297 [inlined]
 [11] #build_null_solution#50
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:701 [inlined]
 [12] build_null_solution
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:696 [inlined]
 [13] #solve_call#44
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:604 [inlined]
 [14] solve_call
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:569 [inlined]
 [15] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1072 [inlined]
 [16] solve_up
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1066 [inlined]
 [17] #solve#51
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1003 [inlined]
 [18] solve
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:993 [inlined]
 [19] _initialize_dae!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, prob::ODEProblem{…}, alg::OrdinaryDiffEq.OverrideInit{…}, isinplace::Val{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/initialize_dae.jl:158
 [20] _initialize_dae!
```



```julia
Status `~/Documents/dev/julia-MTK-NN/Project.toml`
  [0c46a032] DifferentialEquations v7.13.0
⌃ [b2108857] Lux v0.5.56
⌃ [961ee093] ModelingToolkit v9.19.0
⌃ [f162e290] ModelingToolkitNeuralNets v1.0.2
⌃ [16a59e39] ModelingToolkitStandardLibrary v2.7.2
⌃ [7f7a1694] Optimization v3.26.1
  [42dfb2eb] OptimizationOptimisers v0.2.1
⌃ [1dea7af3] OrdinaryDiffEq v6.84.0
⌃ [91a5bcdd] Plots v1.40.4
⌃ [53ae85a6] SciMLStructures v1.3.0
  [860ef19b] StableRNGs v1.0.2
⌃ [c3572dad] Sundials v4.24.0
⌃ [2efcf032] SymbolicIndexingInterface v0.3.22
Info Packages marked with ⌃ have new versions available and may be upgradable.
```

```julia
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 10 × Apple M1 Max
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =
```

SebastianM-C commented 2 months ago

It looks like you have quite old versions of the packages. Can you try `]up`? I tried to reproduce the error with newer package versions and couldn't, but I'm getting instabilities when evaluating the optimization function at the initial point `x0`:

julia> of(x0, (prob, sol_ref, get_vars, get_refs))
┌ Warning: At t=0.0, dt was forced below floating point epsilon 5.0e-324, and step error estimate = NaN. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase ~/.julia/packages/SciMLBase/HReyK/src/integrator_interface.jl:600

If you build a parameter vector that zeros out the neural network parameters, you can check if the instability is related to the neural network or not. In this case, I tried that and it seems that the system is unstable even with the neural network not doing anything.

While the neural network could theoretically have parameters that make the model stable, it's very hard to train it if the starting point is unstable, since that gives an infinite loss and doesn't inform the optimizer in any way. To have a more manageable problem, the starting point should have a finite loss value.
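A quick way to act on this is to evaluate the loss at each candidate starting point before handing it to the optimizer. The sketch below is only an illustration in plain Julia: `toy_loss` and `first_finite_start` are hypothetical names, and `toy_loss` merely imitates the `Inf` branch of the `loss` function in the MRE rather than solving any ODE.

```julia
# Hypothetical stand-in for the thread's `loss`: it returns Inf when the
# pretend "simulation" fails, mirroring the `successful_retcode` branch.
function toy_loss(x)
    simulated_ok = all(abs.(x) .< 1.0)   # pretend the solver diverges otherwise
    return simulated_ok ? sum(abs2, x) : Inf
end

# Return the first candidate whose loss is finite, or `nothing` if none is.
function first_finite_start(loss, candidates)
    for x in candidates
        isfinite(loss(x)) && return x
    end
    return nothing
end

x_bad = [2.0, 0.5]      # unstable start: loss is Inf, so no useful gradient
x_zero = zero(x_bad)    # zeroed-out parameters as a fallback start
x_start = first_finite_start(toy_loss, [x_bad, x_zero])

@show toy_loss(x_bad) toy_loss(x_zero) x_start
```

In the actual problem, the same check would amount to calling `of(x0, (prob, sol_ref, get_vars, get_refs))` and testing the result with `isfinite` before starting `Adam`.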

ggkountouras commented 1 month ago

I tried `]up` and now the optimization `res = solve(op, Adam(5e-3); maxiters = 100, callback = cb)` goes through (fewer iterations, since we're only interested in the shapes here).

I don't get any instabilities or warnings about them. There are 3 unrelated warnings when setting up the problems.

The system with all neural network parameters set to 0 solves fine. Obviously the result is not correct.

Now the line `res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)` hangs indefinitely. There's no warning or other output.

ChrisRackauckas commented 1 month ago

Instrument the solver and check that its stepping makes sense.
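One possible way to do that, sketched here under assumptions, is to drive the integrator manually with `init` and `step!` and print the time and step size as it goes. The toy scalar ODE below only stands in for `res_prob`, and `trace_steps` is a hypothetical helper name, not something from the thread.

```julia
using OrdinaryDiffEq

# Toy problem standing in for `res_prob` from the thread (assumption).
f(u, p, t) = -50.0 * (u - cos(t))
toy_prob = ODEProblem(f, 0.0, (0.0, 1.0))

# Step the integrator by hand, printing progress every `every` steps.
# `max_steps` is a hard cap so a stalled solve cannot hang the session.
function trace_steps(prob, alg; max_steps = 1_000, every = 10)
    integrator = init(prob, alg; abstol = 1e-8, reltol = 1e-6)
    nsteps = 0
    while integrator.t < prob.tspan[2] && nsteps < max_steps
        step!(integrator)
        nsteps += 1
        if nsteps % every == 0
            println("step ", nsteps, ": t = ", integrator.t, ", dt = ", integrator.dt)
        end
    end
    return nsteps, integrator.t
end

nsteps, t_end = trace_steps(toy_prob, Rodas4())
```

If `t` stops advancing or `dt` collapses toward zero, the solver is struggling; if nothing is printed at all, the time is likely being spent before stepping starts (for example in compilation).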

SebastianM-C commented 1 month ago

I updated my packages again (maybe I had some weird local state) and I can now reproduce your issue. The issue seems to be that

res_p = SciMLStructures.replace(Tunable(), prob.p, res)

is creating a large `MTKParameters` object that may be blowing up compile time, because `res` is the whole solution object returned by `solve` rather than the parameter vector. The fix is to use `.u` to get the underlying array:

res_p = SciMLStructures.replace(Tunable(), prob.p, res.u)

@ggkountouras Let me know if this fixes your issue.
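To illustrate the difference between `res` and `res.u`, here is a minimal sketch with a toy objective (`toy_objective` is a hypothetical name, and the explicit `using ForwardDiff` is an assumption about what newer Optimization.jl versions need for the `AutoForwardDiff` backend). It shows that `solve` returns a solution object carrying metadata, while `.u` holds the plain vector of optimized values.

```julia
using Optimization
using OptimizationOptimisers: Adam
using ForwardDiff  # assumption: loads the backend used by AutoForwardDiff

# Toy quadratic objective standing in for the thread's `loss` (assumption).
toy_objective(x, p) = sum(abs2, x .- p)

toy_of = OptimizationFunction(toy_objective, AutoForwardDiff())
toy_op = OptimizationProblem(toy_of, [0.0, 0.0], [1.0, -2.0])
toy_res = solve(toy_op, Adam(0.1); maxiters = 500)

@show typeof(toy_res.u)    # plain parameter vector found by the optimizer
@show toy_res.objective    # final objective value stored on the solution object
```

Passing `toy_res.u` (and not `toy_res`) is the analogue of the `res.u` fix above.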

ggkountouras commented 1 month ago

Thanks, that fixes it.

The untrained NN makes the simulation go unstable, but that's expected.