SciML / ModelingToolkit.jl


Expression order bug in the codegen for `initializeprobmap` #3109

Closed: SebastianM-C closed this issue 1 month ago

SebastianM-C commented 1 month ago

Describe the bug 🐞

I'm seeing a bug in the ordering of the expressions in the code generated for `initializeprobmap` when using ModelingToolkitNeuralNets (MTKNN). Currently MTKNN forces the defaults for the nn inputs to zeros because of an initialization warning:

┌ Warning: Internal error: Variable (nn₊input₊u(t))[2] was marked as being in 0 ~ (LuxCore.stateless_apply(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 5, tanh), layer_2 = Dense(5 => 5, tanh), layer_3 = Dense(5 => 2)), nothing), nn₊input₊u(t), convert(nn₊T, nn₊p)))[2] - (nn₊output₊u(t))[2], but was actually zero
└ @ ModelingToolkit.StructuralTransformations ~/.julia/dev/ModelingToolkit/src/structural_transformation/utils.jl:237

If I don't provide the defaults, I instead hit the `OverrideInit` dispatch of `_initialize_dae!`, and the code generated for the `initializeprobmap` `getu` is:

julia> prob.f.initializeprobmap.obsfn.var"#515#_fn"
RuntimeGeneratedFunction(#=in ModelingToolkit=#, #=using ModelingToolkit=#, :((var"##arg#5964805160111424296", ___mtkparameters___)->begin
          #= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:385 =#
          #= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:386 =#
          #= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:387 =#
          begin
              var"##arg#1300299927463060512" = ___mtkparameters___[1]
              var"##arg#5212838221877633925" = ___mtkparameters___[2]
              var"##arg#14786571317958501831" = ___mtkparameters___[3]
              begin
                  var"nn₊p[1]" = var"##arg#1300299927463060512"[1]
                  var"nn₊p[2]" = var"##arg#1300299927463060512"[2]
                  var"nn₊p[3]" = var"##arg#1300299927463060512"[3]
                  var"nn₊p[4]" = var"##arg#1300299927463060512"[4]
                  var"nn₊p[5]" = var"##arg#1300299927463060512"[5]
                  var"nn₊p[6]" = var"##arg#1300299927463060512"[6]
                  var"nn₊p[7]" = var"##arg#1300299927463060512"[7]
                  var"nn₊p[8]" = var"##arg#1300299927463060512"[8]
                  var"nn₊p[9]" = var"##arg#1300299927463060512"[9]
                  var"nn₊p[10]" = var"##arg#1300299927463060512"[10]
                  var"nn₊p[11]" = var"##arg#1300299927463060512"[11]
                  var"nn₊p[12]" = var"##arg#1300299927463060512"[12]
                  var"nn₊p[13]" = var"##arg#1300299927463060512"[13]
                  var"nn₊p[14]" = var"##arg#1300299927463060512"[14]
                  var"nn₊p[15]" = var"##arg#1300299927463060512"[15]
                  var"nn₊p[16]" = var"##arg#1300299927463060512"[16]
                  var"nn₊p[17]" = var"##arg#1300299927463060512"[17]
                  var"nn₊p[18]" = var"##arg#1300299927463060512"[18]
                  var"nn₊p[19]" = var"##arg#1300299927463060512"[19]
                  var"nn₊p[20]" = var"##arg#1300299927463060512"[20]
                  var"nn₊p[21]" = var"##arg#1300299927463060512"[21]
                  var"nn₊p[22]" = var"##arg#1300299927463060512"[22]
                  var"nn₊p[23]" = var"##arg#1300299927463060512"[23]
                  var"nn₊p[24]" = var"##arg#1300299927463060512"[24]
                  var"nn₊p[25]" = var"##arg#1300299927463060512"[25]
                  var"nn₊p[26]" = var"##arg#1300299927463060512"[26]
                  var"nn₊p[27]" = var"##arg#1300299927463060512"[27]
                  var"nn₊p[28]" = var"##arg#1300299927463060512"[28]
                  var"nn₊p[29]" = var"##arg#1300299927463060512"[29]
                  var"nn₊p[30]" = var"##arg#1300299927463060512"[30]
                  var"nn₊p[31]" = var"##arg#1300299927463060512"[31]
                  var"nn₊p[32]" = var"##arg#1300299927463060512"[32]
                  var"nn₊p[33]" = var"##arg#1300299927463060512"[33]
                  var"nn₊p[34]" = var"##arg#1300299927463060512"[34]
                  var"nn₊p[35]" = var"##arg#1300299927463060512"[35]
                  var"nn₊p[36]" = var"##arg#1300299927463060512"[36]
                  var"nn₊p[37]" = var"##arg#1300299927463060512"[37]
                  var"nn₊p[38]" = var"##arg#1300299927463060512"[38]
                  var"nn₊p[39]" = var"##arg#1300299927463060512"[39]
                  var"nn₊p[40]" = var"##arg#1300299927463060512"[40]
                  var"nn₊p[41]" = var"##arg#1300299927463060512"[41]
                  var"nn₊p[42]" = var"##arg#1300299927463060512"[42]
                  var"nn₊p[43]" = var"##arg#1300299927463060512"[43]
                  var"nn₊p[44]" = var"##arg#1300299927463060512"[44]
                  var"nn₊p[45]" = var"##arg#1300299927463060512"[45]
                  var"nn₊p[46]" = var"##arg#1300299927463060512"[46]
                  var"nn₊p[47]" = var"##arg#1300299927463060512"[47]
                  var"nn₊p[48]" = var"##arg#1300299927463060512"[48]
                  var"nn₊p[49]" = var"##arg#1300299927463060512"[49]
                  var"nn₊p[50]" = var"##arg#1300299927463060512"[50]
                  var"nn₊p[51]" = var"##arg#1300299927463060512"[51]
                  var"nn₊p[52]" = var"##arg#1300299927463060512"[52]
                  var"nn₊p[53]" = var"##arg#1300299927463060512"[53]
                  var"nn₊p[54]" = var"##arg#1300299927463060512"[54]
                  var"nn₊p[55]" = var"##arg#1300299927463060512"[55]
                  var"nn₊p[56]" = var"##arg#1300299927463060512"[56]
                  var"nn₊p[57]" = var"##arg#1300299927463060512"[57]
                  t = var"##arg#1300299927463060512"[58]
                  lotka₊δ = var"##arg#5212838221877633925"[1]
                  lotka₊α = var"##arg#5212838221877633925"[2]
                  nn₊T = var"##arg#14786571317958501831"[1]
                  begin
                      nn₊p = reshape(view(var"##arg#1300299927463060512", 1:57), (57,))
                      begin
                          begin
                              var"lotka₊x(t)" = 3.1
                              var"lotka₊y(t)" = 1.5
                              var"(nn₊output₊u(t))[1]" = (getindex)((LuxCore.stateless_apply)(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 5, tanh), layer_2 = Dense(5 => 5, tanh), layer_3 = Dense(5 => 2)), nothing), begin
                                              #= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:480 =#
                                              (SymbolicUtils.Code.create_array)(OffsetArrays.OffsetVector{SymbolicUtils.BasicSymbolic{Real}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, nothing, Val{1}(), Val{(2,)}(), var"(nn₊input₊u(t))[1]", var"(nn₊input₊u(t))[2]")
                                          end, (convert)(nn₊T, nn₊p)), 1)
                              var"(nn₊output₊u(t))[2]" = (getindex)((LuxCore.stateless_apply)(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 5, tanh), layer_2 = Dense(5 => 5, tanh), layer_3 = Dense(5 => 2)), nothing), begin
                                              #= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:480 =#
                                              (SymbolicUtils.Code.create_array)(OffsetArrays.OffsetVector{SymbolicUtils.BasicSymbolic{Real}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, nothing, Val{1}(), Val{(2,)}(), var"(nn₊input₊u(t))[1]", var"(nn₊input₊u(t))[2]")
                                          end, (convert)(nn₊T, nn₊p)), 2)
                              var"(nn₊input₊u(t))[1]" = var"lotka₊x(t)"
                              var"(nn₊input₊u(t))[2]" = var"lotka₊y(t)"
                              begin
                                  #= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:480 =#
                                  (SymbolicUtils.Code.create_array)(Array, nothing, Val{1}(), Val{(4,)}(), var"lotka₊x(t)", var"lotka₊y(t)", var"(nn₊input₊u(t))[1]", var"(nn₊input₊u(t))[2]")
                              end
                          end
                      end
                  end
              end
          end
      end))

Note how `var"(nn₊input₊u(t))[1]"` and `var"(nn₊input₊u(t))[2]"` are used before they are assigned.
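To make the failure mode concrete, here is a minimal standalone sketch (with hypothetical variable names, not the actual generated code) of the same use-before-assignment pattern; it fails with the same kind of UndefVarError as in the stacktrace below:

# Minimal sketch of the ordering problem (hypothetical names): the "output"
# expression reads locals that are only assigned afterwards, so calling the
# function throws an UndefVarError, just like the generated getu above.
function misordered(x, y)
    out1 = in1 + in2   # `in1`/`in2` are used here...
    in1 = x            # ...but only assigned here
    in2 = y
    return [x, y, out1]
end

misordered(3.1, 1.5)  # ERROR: UndefVarError: `in1` not defined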

Expected behavior

`solve(prob, Rodas5P())` should work; the assignments in the generated `initializeprobmap` code should be emitted in dependency order, so the inputs are defined before the expressions that use them.
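Concretely, the relevant part of the generated code would presumably need to look roughly like this (output computations elided), with the input assignments emitted first:

# Sketch of the expected assignment order (not actual generated code):
# the nn₊input₊u entries are assigned before the nn₊output₊u expressions
# that read them.
var"lotka₊x(t)" = 3.1
var"lotka₊y(t)" = 1.5
var"(nn₊input₊u(t))[1]" = var"lotka₊x(t)"
var"(nn₊input₊u(t))[2]" = var"lotka₊y(t)"
# var"(nn₊output₊u(t))[1]" = ... LuxCore.stateless_apply(...) reading nn₊input₊u ...
# var"(nn₊output₊u(t))[2]" = ... LuxCore.stateless_apply(...) reading nn₊input₊u ...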

Minimal Reproducible Example 👇


using Test
using ModelingToolkitNeuralNets
using ModelingToolkit
using ModelingToolkitStandardLibrary.Blocks
using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqNonlinearSolve
using SymbolicIndexingInterface
using StableRNGs

function lotka_ude()
    @variables t x(t) = 3.1 y(t) = 1.5
    @parameters α = 1.3 [tunable = false] δ = 1.8 [tunable = false]
    Dt = ModelingToolkit.D_nounits
    @named nn_in = RealInputArray(nin=2)
    @named nn_out = RealOutputArray(nout=2)

    eqs = [
        Dt(x) ~ α * x + nn_in.u[1],
        Dt(y) ~ -δ * y + nn_in.u[2],
        nn_out.u[1] ~ x,
        nn_out.u[2] ~ y
    ]
    return ODESystem(
        eqs, ModelingToolkit.t_nounits, name=:lotka, systems=[nn_in, nn_out])
end

function lotka_true()
    @variables t x(t) = 3.1 y(t) = 1.5
    @parameters α = 1.3 β = 0.9 γ = 0.8 δ = 1.8
    Dt = ModelingToolkit.D_nounits

    eqs = [
        Dt(x) ~ α * x - β * x * y,
        Dt(y) ~ -δ * y + γ * x * y
    ]
    return ODESystem(eqs, ModelingToolkit.t_nounits, name=:lotka_true)
end

model = lotka_ude()

chain = multi_layer_feed_forward(2, 2)
@named nn = NeuralNetworkBlock(2, 2; chain, rng=StableRNG(42))

eqs = [connect(model.nn_in, nn.output)
    connect(model.nn_out, nn.input)]

ude_sys = complete(ODESystem(
    eqs, ModelingToolkit.t_nounits, systems=[model, nn],
    name=:ude_sys,
    # defaults=[nn.input.u => [0.0, 0.0]]  # deliberately left out to reproduce the issue
))

sys = structural_simplify(ude_sys)

prob = ODEProblem{true,SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])

iprob = ModelingToolkit.InitializationProblem(sys, 0.0)
solve(iprob)

solve(prob, Rodas5P())
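The error below comes from the final `solve(prob, Rodas5P())` call above. As a side note, based on the description at the top, uncommenting the `defaults` line presumably sidesteps the mis-ordered code path, at the cost of the structural-transformation warning quoted earlier; a sketch (with a hypothetical system name):

# Workaround sketch based on the description above: explicit zero defaults for
# the network inputs avoid the OverrideInit path with the mis-ordered
# assignments, but trigger the "was actually zero" warning quoted at the top.
ude_sys_wd = complete(ODESystem(
    eqs, ModelingToolkit.t_nounits, systems=[model, nn],
    name=:ude_sys_wd,
    defaults=[nn.input.u => [0.0, 0.0]]))

sys_wd = structural_simplify(ude_sys_wd)
prob_wd = ODEProblem{true,SciMLBase.FullSpecialize}(sys_wd, [], (0, 1.0), [])
solve(prob_wd, Rodas5P())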

Error & Stacktrace ⚠️

ERROR: UndefVarError: `(nn₊input₊u(t))[1]` not defined in local scope
Suggestion: check for an assignment to a local variable that shadows a global of the same name.
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:480 [inlined]
  [2] macro expansion
    @ ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:163 [inlined]
  [3] macro expansion
    @ ./none:0 [inlined]
  [4] generated_callfunc
    @ ./none:0 [inlined]
  [5] (::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{…})(::Vector{…}, ::MTKParameters{…})
    @ RuntimeGeneratedFunctions ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:150
  [6] (::ModelingToolkit.var"#fn2#272"{…})(u::Vector{…}, p::MTKParameters{…})
    @ ModelingToolkit ~/.julia/dev/ModelingToolkit/src/systems/abstractsystem.jl:840
  [7] (::SymbolicIndexingInterface.TimeIndependentObservedFunction{…})(::NotTimeseries, prob::SciMLBase.NonlinearSolution{…})
    @ SymbolicIndexingInterface ~/.julia/packages/SymbolicIndexingInterface/cwAFH/src/state_indexing.jl:142
  [8] (::SymbolicIndexingInterface.TimeIndependentObservedFunction{…})(prob::SciMLBase.NonlinearSolution{…})
    @ SymbolicIndexingInterface ~/.julia/packages/SymbolicIndexingInterface/cwAFH/src/value_provider_interface.jl:166
  [9] _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, prob::ODEProblem{…}, alg::OrdinaryDiffEqCore.OverrideInit{…}, isinplace::Val{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/initialize_dae.jl:174
 [10] _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, prob::ODEProblem{…}, alg::OrdinaryDiffEqCore.DefaultInit, x::Val{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/initialize_dae.jl:60
 [11] initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, initializealg::OrdinaryDiffEqCore.DefaultInit)
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/initialize_dae.jl:50
 [12] __init(prob::ODEProblem{…}, alg::Rodas5P{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Rational{…}, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEqCore.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:503
 [13] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:11 [inlined]
 [14] __solve(::ODEProblem{…}, ::Rodas5P{…}; kwargs::@Kwargs{})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:6
 [15] __solve
    @ ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:1 [inlined]
 [16] #solve_call#44
    @ ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:612 [inlined]
 [17] solve_call(_prob::ODEProblem{…}, args::Rodas5P{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:569
 [18] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::MTKParameters{…}, args::Rodas5P{…}; kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1092
 [19] solve_up
    @ ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1078 [inlined]
 [20] solve(prob::ODEProblem{…}, args::Rodas5P{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1015
 [21] solve(prob::ODEProblem{…}, args::Rodas5P{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1005
 [22] top-level scope
